|
50 | 50 | ]
|
51 | 51 |
|
52 | 52 |
|
53 |
| -def _call_syrk(x1, x2): |
54 |
| - """ |
55 |
| - Check to see if `syrk` can be called instead of `gemm`. |
56 |
| -
|
57 |
| - It is assumed that x1 and x2 are usm_ndarray objects. These arrays have |
58 |
| - already been validated to be 2-dimensional and contiguous. Therefore, this |
59 |
| - function only verifies the following: Both arrays reference the same |
60 |
| - memory. The number of rows in x1 equals the number of columns in x2. If one |
61 |
| - array is C-contiguous, the other must be F-contiguous. |
62 |
| -
|
63 |
| - """ |
64 |
| - call_syrk = False |
65 |
| - if ( |
66 |
| - x1._pointer == x2._pointer |
67 |
| - and x1.shape[0] == x2.shape[1] |
68 |
| - and x1.flags.c_contiguous != x2.flags.c_contiguous |
69 |
| - and x1.flags.f_contiguous != x2.flags.f_contiguous |
70 |
| - ): |
71 |
| - call_syrk = True |
72 |
| - |
73 |
| - return call_syrk |
74 |
| - |
75 |
| - |
76 | 53 | def _compute_res_dtype(*arrays, dtype=None, out=None, casting="no"):
|
77 | 54 | """
|
78 | 55 | Determines the output array data type.
|
@@ -541,6 +518,29 @@ def _get_signature(func):
|
541 | 518 | return signature, distinct_core
|
542 | 519 |
|
543 | 520 |
|
def _is_syrk_compatible(x1, x2):
    """
    Check to see if `syrk` can be called instead of `gemm`.

    The `syrk` call site passes only `x1`, so it is a valid replacement for
    the product of `x1` and `x2` only when `x2` is the transpose of `x1`.
    Input arrays have already been validated to be 2-dimensional.

    Parameters
    ----------
    x1, x2 : array
        Operands of the product. Each must be accepted by
        `dpnp.get_usm_ndarray` and expose `dtype`, `shape` and `strides`.

    Returns
    -------
    bool
        ``True`` if all checks below pass, ``False`` otherwise.

    """
    # Must share data (same starting address)
    if dpnp.get_usm_ndarray(x1)._pointer != dpnp.get_usm_ndarray(x2)._pointer:
        return False

    # The same memory reinterpreted with another element type is not a
    # transpose of `x1`, even if shape and strides line up
    if x1.dtype != x2.dtype:
        return False

    # Result must be square
    if x1.shape[0] != x2.shape[1]:
        return False

    # Strides must match transpose pattern
    x1_strides = x1.strides
    x2_strides = x2.strides
    return x1_strides[0] == x2_strides[1] and x1_strides[1] == x2_strides[0]
| 542 | + |
| 543 | + |
544 | 544 | def _shape_error(shape1, shape2, func, err_msg):
|
545 | 545 | """Validate the shapes of input and output arrays."""
|
546 | 546 |
|
@@ -983,6 +983,11 @@ def dpnp_multiplication(
|
983 | 983 | x1 = dpnp.reshape(x1, x1_shape[-2:])
|
984 | 984 | x2 = dpnp.reshape(x2, x2_shape[-2:])
|
985 | 985 | res_shape = (x1_shape[-2], x2_shape[-1])
|
| 986 | + if _is_syrk_compatible(x1, x2): |
| 987 | + call_flag = "syrk" |
| 988 | + res_dtype_orig = res_dtype |
| 989 | + if dpnp.issubdtype(res_dtype, dpnp.integer): |
| 990 | + res_dtype = dpnp.default_float_type(x1.device) |
986 | 991 | elif x1_base_is_1D:
|
987 | 992 | # TODO: implement gemv_batch to use it here with transpose
|
988 | 993 | call_flag = "gemm_batch"
|
@@ -1088,21 +1093,17 @@ def dpnp_multiplication(
|
1088 | 1093 | depends=_manager.submitted_events,
|
1089 | 1094 | )
|
1090 | 1095 | _manager.add_event_pair(ht_ev, gemv_ev)
|
| 1096 | + elif call_flag == "syrk": |
| 1097 | + _manager = dpu.SequentialOrderManager[exec_q] |
| 1098 | + ht_ev, gemv_ev = bi._syrk( |
| 1099 | + exec_q, |
| 1100 | + dpnp.get_usm_ndarray(x1), |
| 1101 | + dpnp.get_usm_ndarray(result), |
| 1102 | + depends=_manager.submitted_events, |
| 1103 | + ) |
| 1104 | + _manager.add_event_pair(ht_ev, gemv_ev) |
1091 | 1105 | elif call_flag == "gemm":
|
1092 |
| - x1_usm = dpnp.get_usm_ndarray(x1) |
1093 |
| - x2_usm = dpnp.get_usm_ndarray(x2) |
1094 |
| - call_syrk = _call_syrk(x1_usm, x2_usm) |
1095 |
| - if call_syrk: |
1096 |
| - _manager = dpu.SequentialOrderManager[exec_q] |
1097 |
| - ht_ev, gemv_ev = bi._syrk( |
1098 |
| - exec_q, |
1099 |
| - x1_usm, |
1100 |
| - dpnp.get_usm_ndarray(result), |
1101 |
| - depends=_manager.submitted_events, |
1102 |
| - ) |
1103 |
| - _manager.add_event_pair(ht_ev, gemv_ev) |
1104 |
| - else: |
1105 |
| - result = _gemm_matmul(exec_q, x1_usm, x2_usm, result) |
| 1106 | + result = _gemm_matmul(exec_q, x1, x2, result) |
1106 | 1107 | else:
|
1107 | 1108 | assert call_flag == "gemm_batch"
|
1108 | 1109 | result = _gemm_batch_matmul(exec_q, x1, x2, result)
|
@@ -1130,6 +1131,9 @@ def dpnp_multiplication(
|
1130 | 1131 | elif res_shape != result_shape:
|
1131 | 1132 | result = dpnp.reshape(result, result_shape)
|
1132 | 1133 |
|
| 1134 | + if call_flag == "syrk" and res_dtype_orig != res_dtype: |
| 1135 | + result = result.astype(res_dtype_orig) |
| 1136 | + |
1133 | 1137 | if out is None:
|
1134 | 1138 | if axes is not None:
|
1135 | 1139 | # Move the data back to the appropriate axes of the result array
|
|
0 commit comments