Skip to content

Commit 1a909f3

Browse files
committed
impl-syrk
1 parent e3a9b5e commit 1a909f3

File tree

13 files changed

+521
-63
lines changed

13 files changed

+521
-63
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010

1111
* Added `--target-cuda[=ARCH]` option to replace the deprecated `--target=cuda`, allowing users to build for CUDA devices with optional architecture selection using [CodePlay oneAPI plug-in](https://developer.codeplay.com/products/oneapi/nvidia/home/) [#2478](https://github.com/IntelPython/dpnp/pull/2478)
1212
* Added several new `pre-commit` rules, including protection against direct commits to master/maintenance branches [#2500](https://github.com/IntelPython/dpnp/pull/2500)
13+
* Added a new backend routine `syrk` from oneMKL to perform symmetric rank-k update which is used for a specialized matrix multiplication where the result is a symmetric matrix [2509](https://github.com/IntelPython/dpnp/pull/2509)
1314

1415
### Changed
1516

dpnp/backend/extensions/blas/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ set(_module_src
3030
${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
3131
${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp
3232
${CMAKE_CURRENT_SOURCE_DIR}/gemv.cpp
33+
${CMAKE_CURRENT_SOURCE_DIR}/syrk.cpp
3334
)
3435

3536
pybind11_add_module(${python_module_name} MODULE ${_module_src})

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "dotu.hpp"
3737
#include "gemm.hpp"
3838
#include "gemv.hpp"
39+
#include "syrk.hpp"
3940

4041
namespace blas_ns = dpnp::extensions::blas;
4142
namespace py = pybind11;
@@ -48,6 +49,7 @@ void init_dispatch_vectors_tables(void)
4849
blas_ns::init_gemm_batch_dispatch_table();
4950
blas_ns::init_gemm_dispatch_table();
5051
blas_ns::init_gemv_dispatch_vector();
52+
blas_ns::init_syrk_dispatch_vector();
5153
}
5254

5355
static dot_impl_fn_ptr_t dot_dispatch_vector[dpctl_td_ns::num_types];
@@ -73,7 +75,7 @@ PYBIND11_MODULE(_blas_impl, m)
7375
};
7476

7577
m.def("_dot", dot_pyapi,
76-
"Call `dot` from OneMKL BLAS library to compute "
78+
"Call `dot` from oneMKL BLAS library to compute "
7779
"the dot product of two real-valued vectors.",
7880
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
7981
py::arg("result"), py::arg("depends") = py::list());
@@ -91,7 +93,7 @@ PYBIND11_MODULE(_blas_impl, m)
9193
};
9294

9395
m.def("_dotc", dotc_pyapi,
94-
"Call `dotc` from OneMKL BLAS library to compute "
96+
"Call `dotc` from oneMKL BLAS library to compute "
9597
"the dot product of two complex vectors, "
9698
"conjugating the first vector.",
9799
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
@@ -110,37 +112,45 @@ PYBIND11_MODULE(_blas_impl, m)
110112
};
111113

112114
m.def("_dotu", dotu_pyapi,
113-
"Call `dotu` from OneMKL BLAS library to compute "
115+
"Call `dotu` from oneMKL BLAS library to compute "
114116
"the dot product of two complex vectors.",
115117
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
116118
py::arg("result"), py::arg("depends") = py::list());
117119
}
118120

119121
{
120122
m.def("_gemm", &blas_ns::gemm,
121-
"Call `gemm` from OneMKL BLAS library to compute "
123+
"Call `gemm` from oneMKL BLAS library to compute "
122124
"the matrix-matrix product with 2-D matrices.",
123125
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
124126
py::arg("resultC"), py::arg("depends") = py::list());
125127
}
126128

127129
{
128130
m.def("_gemm_batch", &blas_ns::gemm_batch,
129-
"Call `gemm_batch` from OneMKL BLAS library to compute "
131+
"Call `gemm_batch` from oneMKL BLAS library to compute "
130132
"the matrix-matrix product for a batch of 2-D matrices.",
131133
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("matrixB"),
132134
py::arg("resultC"), py::arg("depends") = py::list());
133135
}
134136

135137
{
136138
m.def("_gemv", &blas_ns::gemv,
137-
"Call `gemv` from OneMKL BLAS library to compute "
139+
"Call `gemv` from oneMKL BLAS library to compute "
138140
"the matrix-vector product with a general matrix.",
139141
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("vectorX"),
140142
py::arg("vectorY"), py::arg("transpose"),
141143
py::arg("depends") = py::list());
142144
}
143145

146+
{
147+
m.def("_syrk", &blas_ns::syrk,
148+
"Call `syrk` from oneMKL BLAS library to compute "
149+
"the matrix-vector product with a general matrix.",
150+
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("resultC"),
151+
py::arg("depends") = py::list());
152+
}
153+
144154
{
145155
m.def(
146156
"_using_onemath",

dpnp/backend/extensions/blas/gemm.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,7 @@ static sycl::event gemm_impl(sycl::queue &exec_q,
129129
Tab(1), // Scaling factor for the product of matrices A and B.
130130
a, // Pointer to matrix A.
131131
lda, // Leading dimension of matrix A, which is the
132-
// stride between successive rows (for row major
133-
// layout).
132+
// stride between successive rows (for row major layout).
134133
b, // Pointer to matrix B.
135134
ldb, // Leading dimension of matrix B, similar to lda.
136135
Tab(0), // Scaling factor for matrix C.
@@ -168,7 +167,8 @@ std::tuple<sycl::event, sycl::event, bool>
168167
const int resultC_nd = resultC.get_ndim();
169168

170169
if ((matrixA_nd != 2) || (matrixB_nd != 2) || (resultC_nd != 2)) {
171-
throw py::value_error("Input matrices must be two-dimensional.");
170+
throw py::value_error(
171+
"Input and output matrices must be two-dimensional.");
172172
}
173173

174174
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
@@ -286,6 +286,8 @@ std::tuple<sycl::event, sycl::event, bool>
286286
}
287287
}
288288
else {
289+
// both A and B are f_contig so using column-major gemm and
290+
// no transpose is needed
289291
transA = oneapi::mkl::transpose::N;
290292
transB = oneapi::mkl::transpose::N;
291293
lda = m;

dpnp/backend/extensions/blas/gemv.cpp

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,7 @@ static sycl::event gemv_impl(sycl::queue &exec_q,
118118
T(1), // Scaling factor for the matrix-vector product.
119119
a, // Pointer to the input matrix A.
120120
lda, // Leading dimension of matrix A, which is the
121-
// stride between successive rows (for row major
122-
// layout).
121+
// stride between successive rows (for row major layout).
123122
x, // Pointer to the input vector x.
124123
incx, // The stride of vector x.
125124
T(0), // Scaling factor for vector y.
@@ -190,6 +189,26 @@ std::pair<sycl::event, sycl::event>
190189
const py::ssize_t *a_shape = matrixA.get_shape_raw();
191190
const py::ssize_t *x_shape = vectorX.get_shape_raw();
192191
const py::ssize_t *y_shape = vectorY.get_shape_raw();
192+
if (transpose) {
193+
if (a_shape[0] != x_shape[0]) {
194+
throw py::value_error("The number of rows in A must be equal to "
195+
"the number of elements in X.");
196+
}
197+
if (a_shape[1] != y_shape[0]) {
198+
throw py::value_error("The number of columns in A must be equal to "
199+
"the number of elements in Y.");
200+
}
201+
}
202+
else {
203+
if (a_shape[1] != x_shape[0]) {
204+
throw py::value_error("The number of columns in A must be equal to "
205+
"the number of elements in X.");
206+
}
207+
if (a_shape[0] != y_shape[0]) {
208+
throw py::value_error("The number of rows in A must be equal to "
209+
"the number of elements in Y.");
210+
}
211+
}
193212

194213
oneapi::mkl::transpose transA;
195214
std::size_t src_nelems;
@@ -243,27 +262,6 @@ std::pair<sycl::event, sycl::event>
243262
}
244263
#endif // USE_ONEMATH_CUBLAS
245264

246-
if (transpose) {
247-
if (a_shape[0] != x_shape[0]) {
248-
throw py::value_error("The number of rows in A must be equal to "
249-
"the number of elements in X.");
250-
}
251-
if (a_shape[1] != y_shape[0]) {
252-
throw py::value_error("The number of columns in A must be equal to "
253-
"the number of elements in Y.");
254-
}
255-
}
256-
else {
257-
if (a_shape[1] != x_shape[0]) {
258-
throw py::value_error("The number of columns in A must be equal to "
259-
"the number of elements in X.");
260-
}
261-
if (a_shape[0] != y_shape[0]) {
262-
throw py::value_error("The number of rows in A must be equal to "
263-
"the number of elements in Y.");
264-
}
265-
}
266-
267265
const std::int64_t lda = is_row_major ? n : m;
268266
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(vectorY);
269267
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(vectorY,
@@ -287,7 +285,7 @@ std::pair<sycl::event, sycl::event>
287285
"Types of input arrays and result array are mismatched.");
288286
}
289287

290-
char *a_typeless_ptr = matrixA.get_data();
288+
const char *a_typeless_ptr = matrixA.get_data();
291289
char *x_typeless_ptr = vectorX.get_data();
292290
char *y_typeless_ptr = vectorY.get_data();
293291

dpnp/backend/extensions/blas/gemv.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,4 @@ extern std::pair<sycl::event, sycl::event>
4141
const std::vector<sycl::event> &depends);
4242

4343
extern void init_gemv_dispatch_vector(void);
44-
extern void init_gemv_batch_dispatch_vector(void);
4544
} // namespace dpnp::extensions::blas

0 commit comments

Comments
 (0)