Skip to content

Commit c76f58d

Browse files
committed
use syrk for int dtypes when possible
1 parent a387235 commit c76f58d

File tree

2 files changed

+69
-39
lines changed

2 files changed

+69
-39
lines changed

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 41 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -50,29 +50,6 @@
5050
]
5151

5252

53-
def _call_syrk(x1, x2):
54-
"""
55-
Check to see if `syrk` can be called instead of `gemm`.
56-
57-
It is assumed that x1 and x2 are usm_ndarray objects. These arrays have
58-
already been validated to be 2-dimensional and contiguous. Therefore, this
59-
function only verifies the following: Both arrays reference the same
60-
memory. The number of rows in x1 equals the number of columns in x2. If one
61-
array is C-contiguous, the other must be F-contiguous.
62-
63-
"""
64-
call_syrk = False
65-
if (
66-
x1._pointer == x2._pointer
67-
and x1.shape[0] == x2.shape[1]
68-
and x1.flags.c_contiguous != x2.flags.c_contiguous
69-
and x1.flags.f_contiguous != x2.flags.f_contiguous
70-
):
71-
call_syrk = True
72-
73-
return call_syrk
74-
75-
7653
def _compute_res_dtype(*arrays, dtype=None, out=None, casting="no"):
7754
"""
7855
Determines the output array data type.
@@ -541,6 +518,29 @@ def _get_signature(func):
541518
return signature, distinct_core
542519

543520

521+
def _is_syrk_compatible(x1, x2):
522+
"""
523+
Check to see if `syrk` can be called instead of `gemm`.
524+
Input arrays have already been validated to be 2-dimensional.
525+
526+
"""
527+
# Must start at the same memory address (identical data pointers)
528+
if dpnp.get_usm_ndarray(x1)._pointer != dpnp.get_usm_ndarray(x2)._pointer:
529+
return False
530+
531+
# Result must be square
532+
if x1.shape[0] != x2.shape[1]:
533+
return False
534+
535+
# Strides must match transpose pattern
536+
x1_strides = x1.strides
537+
x2_strides = x2.strides
538+
if x1_strides[0] != x2_strides[1] or x1_strides[1] != x2_strides[0]:
539+
return False
540+
541+
return True
542+
543+
544544
def _shape_error(shape1, shape2, func, err_msg):
545545
"""Validate the shapes of input and output arrays."""
546546

@@ -983,6 +983,11 @@ def dpnp_multiplication(
983983
x1 = dpnp.reshape(x1, x1_shape[-2:])
984984
x2 = dpnp.reshape(x2, x2_shape[-2:])
985985
res_shape = (x1_shape[-2], x2_shape[-1])
986+
if _is_syrk_compatible(x1, x2):
987+
call_flag = "syrk"
988+
res_dtype_orig = res_dtype
989+
if dpnp.issubdtype(res_dtype, dpnp.integer):
990+
res_dtype = dpnp.default_float_type(x1.device)
986991
elif x1_base_is_1D:
987992
# TODO: implement gemv_batch to use it here with transpose
988993
call_flag = "gemm_batch"
@@ -1088,21 +1093,17 @@ def dpnp_multiplication(
10881093
depends=_manager.submitted_events,
10891094
)
10901095
_manager.add_event_pair(ht_ev, gemv_ev)
1096+
elif call_flag == "syrk":
1097+
_manager = dpu.SequentialOrderManager[exec_q]
1098+
ht_ev, gemv_ev = bi._syrk(
1099+
exec_q,
1100+
dpnp.get_usm_ndarray(x1),
1101+
dpnp.get_usm_ndarray(result),
1102+
depends=_manager.submitted_events,
1103+
)
1104+
_manager.add_event_pair(ht_ev, gemv_ev)
10911105
elif call_flag == "gemm":
1092-
x1_usm = dpnp.get_usm_ndarray(x1)
1093-
x2_usm = dpnp.get_usm_ndarray(x2)
1094-
call_syrk = _call_syrk(x1_usm, x2_usm)
1095-
if call_syrk:
1096-
_manager = dpu.SequentialOrderManager[exec_q]
1097-
ht_ev, gemv_ev = bi._syrk(
1098-
exec_q,
1099-
x1_usm,
1100-
dpnp.get_usm_ndarray(result),
1101-
depends=_manager.submitted_events,
1102-
)
1103-
_manager.add_event_pair(ht_ev, gemv_ev)
1104-
else:
1105-
result = _gemm_matmul(exec_q, x1_usm, x2_usm, result)
1106+
result = _gemm_matmul(exec_q, x1, x2, result)
11061107
else:
11071108
assert call_flag == "gemm_batch"
11081109
result = _gemm_batch_matmul(exec_q, x1, x2, result)
@@ -1130,6 +1131,9 @@ def dpnp_multiplication(
11301131
elif res_shape != result_shape:
11311132
result = dpnp.reshape(result, result_shape)
11321133

1134+
if call_flag == "syrk" and res_dtype_orig != res_dtype:
1135+
result = result.astype(res_dtype_orig)
1136+
11331137
if out is None:
11341138
if axes is not None:
11351139
# Move the data back to the appropriate axes of the result array

dpnp/tests/test_product.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
assert_dtype_allclose,
1313
generate_random_numpy_array,
1414
get_all_dtypes,
15-
get_float_complex_dtypes,
1615
numpy_version,
1716
)
1817
from .third_party.cupy import testing
@@ -1184,7 +1183,7 @@ def test_special_case(self, dt_out, shape1, shape2):
11841183
result = dpnp.matmul(ia, ib, out=iout)
11851184
assert_dtype_allclose(result, expected)
11861185

1187-
@pytest.mark.parametrize("dt", get_float_complex_dtypes())
1186+
@pytest.mark.parametrize("dt", get_all_dtypes())
11881187
def test_syrk(self, dt):
11891188
a = generate_random_numpy_array((6, 9), dtype=dt)
11901189
ia = dpnp.array(a)
@@ -1202,6 +1201,21 @@ def test_syrk(self, dt):
12021201
expected = a.T @ a
12031202
assert_dtype_allclose(result, expected)
12041203

1204+
@pytest.mark.parametrize("dt", [dpnp.int32, dpnp.float32])
1205+
def test_syrk_strided(self, dt):
1206+
a = generate_random_numpy_array((20, 30), dtype=dt)
1207+
ia = dpnp.array(a)
1208+
a = a[::2, ::2]
1209+
ia = ia[::2, ::2]
1210+
1211+
result = dpnp.matmul(ia, ia.mT)
1212+
expected = numpy.matmul(a, a.T)
1213+
assert_dtype_allclose(result, expected)
1214+
1215+
result = ia.mT @ ia
1216+
expected = a.T @ a
1217+
assert_dtype_allclose(result, expected)
1218+
12051219
@pytest.mark.parametrize(
12061220
"order, out_order",
12071221
[("C", "C"), ("C", "F"), ("F", "C"), ("F", "F")],
@@ -1226,6 +1240,18 @@ def test_syrk_order(self, order):
12261240
result = dpnp.matmul(ia, ia.mT)
12271241
assert_dtype_allclose(result, expected)
12281242

1243+
# added for coverage
1244+
def test_not_syrk(self):
1245+
a = generate_random_numpy_array((20, 20), low=-5, high=5)
1246+
ia = dpnp.array(a)
1247+
1248+
# Not syrk-compatible: the (20, 10) result is not square
1249+
b = a.mT[:, ::2]
1250+
ib = ia.mT[:, ::2]
1251+
expected = numpy.matmul(a, b)
1252+
result = dpnp.matmul(ia, ib)
1253+
assert_dtype_allclose(result, expected)
1254+
12291255
def test_bool(self):
12301256
a = generate_random_numpy_array((3, 4), dtype=dpnp.bool)
12311257
b = generate_random_numpy_array((4, 5), dtype=dpnp.bool)

0 commit comments

Comments
 (0)