Skip to content

Commit 479fa96

Browse files
aartbik authored and cliffburdick committed
allow for incoming/outgoing transformations on SpMM (#884)
* allow for incoming/outgoing transformations on SpMM; also some auto formatting
* removed stray line
1 parent d75dcf7 commit 479fa96

File tree

2 files changed

+112
-51
lines changed

2 files changed

+112
-51
lines changed

include/matx/transforms/matmul/matmul_cusparse.h

Lines changed: 67 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,8 @@
77
// Redistribution and use in source and binary forms, with or without
88
// modification, are permitted provided that the following conditions are met:
99
//
10-
// 1. Redistributions of source code must retain the above copyright notice, this
11-
// list of conditions and the following disclaimer.
10+
// 1. Redistributions of source code must retain the above copyright notice,
11+
// this list of conditions and the following disclaimer.
1212
//
1313
// 2. Redistributions in binary form must reproduce the above copyright notice,
1414
// this list of conditions and the following disclaimer in the documentation
@@ -20,14 +20,15 @@
2020
//
2121
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
2222
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23-
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24-
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25-
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26-
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27-
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28-
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29-
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30-
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
25+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
26+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
27+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
28+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
29+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
30+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31+
// POSSIBILITY OF SUCH DAMAGE.
3132
/////////////////////////////////////////////////////////////////////////////////
3233

3334
#pragma once
@@ -97,10 +98,9 @@ class MatMulCUSPARSEHandle_t {
9798
std::is_same_v<TC, cuda::std::complex<double>>) {
9899
salpha_ = {alpha, 0};
99100
sbeta_ = {beta, 0};
100-
}
101-
else if constexpr (std::is_same_v<TC, float> ||
102-
std::is_same_v<TC, double>) {
103-
salpha_ = alpha;;
101+
} else if constexpr (std::is_same_v<TC, float> ||
102+
std::is_same_v<TC, double>) {
103+
salpha_ = alpha;
104104
sbeta_ = beta;
105105
} else {
106106
MATX_THROW(matxNotSupported, "SpMM currently only supports uniform FP");
@@ -150,9 +150,9 @@ class MatMulCUSPARSEHandle_t {
150150
// Allocate a workspace for SpMM.
151151
const cusparseSpMMAlg_t algo = CUSPARSE_SPMM_ALG_DEFAULT;
152152
const cudaDataType comptp = dtc; // TODO: support separate comp type?!
153-
ret = cusparseSpMM_bufferSize(handle_, params_.opA, params_.opB,
154-
&salpha_, matA_, matB_, &sbeta_,
155-
matC_, comptp, algo, &workspaceSize_);
153+
ret = cusparseSpMM_bufferSize(handle_, params_.opA, params_.opB, &salpha_,
154+
matA_, matB_, &sbeta_, matC_, comptp, algo,
155+
&workspaceSize_);
156156
MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxMatMulError);
157157
if (workspaceSize_) {
158158
matxAlloc((void **)&workspace_, workspaceSize_, MATX_DEVICE_MEMORY);
@@ -203,8 +203,8 @@ class MatMulCUSPARSEHandle_t {
203203
const cusparseSpMMAlg_t algo = CUSPARSE_SPMM_ALG_DEFAULT;
204204
const cudaDataType comptp = MatXTypeToCudaType<TC>(); // TODO: see above
205205
[[maybe_unused]] cusparseStatus_t ret =
206-
cusparseSpMM(handle_, params_.opA, params_.opB, &salpha_, matA_,
207-
matB_, &sbeta_, matC_, comptp, algo, workspace_);
206+
cusparseSpMM(handle_, params_.opA, params_.opB, &salpha_, matA_, matB_,
207+
&sbeta_, matC_, comptp, algo, workspace_);
208208
MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxMatMulError);
209209
}
210210

@@ -255,49 +255,69 @@ using gemm_cusparse_cache_t =
255255
std::unordered_map<MatMulCUSPARSEParams_t, std::any,
256256
MatMulCUSPARSEParamsKeyHash, MatMulCUSPARSEParamsKeyEq>;
257257

258+
template <typename Op>
259+
__MATX_INLINE__ auto getCuSparseSupportedTensor(const Op &in,
260+
cudaStream_t stream) {
261+
const auto func = [&]() {
262+
if constexpr (is_tensor_view_v<Op>)
263+
return in.Stride(Op::Rank() - 1) == 1;
264+
return true;
265+
};
266+
return GetSupportedTensor(in, func, MATX_ASYNC_DEVICE_MEMORY, stream);
267+
}
268+
258269
} // end namespace detail
259270

260271
template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB>
261-
void sparse_matmul_impl(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b,
262-
const cudaExecutor &exec, float alpha = 1.0,
263-
float beta = 0.0) {
272+
void sparse_matmul_impl(TensorTypeC &C, const TensorTypeA &a,
273+
const TensorTypeB &B, const cudaExecutor &exec,
274+
float alpha = 1.0, float beta = 0.0) {
264275
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
265276
const auto stream = exec.getStream();
266277

267-
using TA = typename TensorTypeA::value_type;
268-
using TB = typename TensorTypeB::value_type;
269-
using TC = typename TensorTypeC::value_type;
278+
// Transform into supported form.
279+
auto b = getCuSparseSupportedTensor(B, stream);
280+
auto c = getCuSparseSupportedTensor(C, stream);
281+
if (!is_matx_transform_op<TensorTypeB>() && !b.isSameView(B)) {
282+
(b = B).run(stream);
283+
}
284+
285+
using atype = TensorTypeA;
286+
using btype = decltype(b);
287+
using ctype = decltype(c);
288+
289+
using TA = typename atype::value_type;
290+
using TB = typename btype::value_type;
291+
using TC = typename ctype::value_type;
270292

271-
static constexpr int RANKA = TensorTypeA::Rank();
272-
static constexpr int RANKB = TensorTypeB::Rank();
273-
static constexpr int RANKC = TensorTypeC::Rank();
293+
static constexpr int RANKA = atype::Rank();
294+
static constexpr int RANKB = btype::Rank();
295+
static constexpr int RANKC = ctype::Rank();
274296

275297
// Restrictions.
276298
static_assert(RANKA == 2 && RANKB == 2 && RANKC == 2,
277299
"tensors must have rank-2");
278-
static_assert(std::is_same_v<TC, TA> &&
279-
std::is_same_v<TC, TB>,
300+
static_assert(std::is_same_v<TC, TA> && std::is_same_v<TC, TB>,
280301
"tensors must have the same data type");
281302
// TODO: allow MIXED-PRECISION computation!
282-
static_assert(std::is_same_v<TC, float> ||
283-
std::is_same_v<TC, double> ||
284-
std::is_same_v<TC, cuda::std::complex<float>> ||
285-
std::is_same_v<TC, cuda::std::complex<double>>,
303+
static_assert(std::is_same_v<TC, float> || std::is_same_v<TC, double> ||
304+
std::is_same_v<TC, cuda::std::complex<float>> ||
305+
std::is_same_v<TC, cuda::std::complex<double>>,
286306
"unsupported data type");
287-
MATX_ASSERT(
288-
a.Size(RANKA - 1) == b.Size(RANKB - 2) &&
289-
c.Size(RANKC - 1) == b.Size(RANKB - 1) &&
290-
c.Size(RANKC - 2) == a.Size(RANKA - 2), matxInvalidSize);
291-
MATX_ASSERT(b.Stride(RANKB - 1) == 1 &&
292-
c.Stride(RANKC - 1) == 1, matxInvalidParameter);
307+
MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 2) &&
308+
c.Size(RANKC - 1) == b.Size(RANKB - 1) &&
309+
c.Size(RANKC - 2) == a.Size(RANKA - 2),
310+
matxInvalidSize);
311+
MATX_ASSERT(b.Stride(RANKB - 1) == 1 && c.Stride(RANKC - 1) == 1,
312+
matxInvalidParameter);
293313

294314
// Get parameters required by these tensors (for caching).
295315
auto params =
296-
detail::MatMulCUSPARSEHandle_t<TensorTypeC, TensorTypeA, TensorTypeB>::GetGemmParams(
316+
detail::MatMulCUSPARSEHandle_t<ctype, atype, btype>::GetGemmParams(
297317
c, a, b, stream, alpha, beta);
298318

299319
// Lookup and cache.
300-
using cache_val_type = detail::MatMulCUSPARSEHandle_t<TensorTypeC, TensorTypeA, TensorTypeB>;
320+
using cache_val_type = detail::MatMulCUSPARSEHandle_t<ctype, atype, btype>;
301321
detail::GetCache().LookupAndExec<detail::gemm_cusparse_cache_t>(
302322
detail::GetCacheIdFromType<detail::gemm_cusparse_cache_t>(), params,
303323
[&]() {
@@ -306,6 +326,11 @@ void sparse_matmul_impl(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB
306326
[&](std::shared_ptr<cache_val_type> cache_type) {
307327
cache_type->Exec(c, a, b);
308328
});
329+
330+
// Copy transformed output back.
331+
if (!c.isSameView(C)) {
332+
(C = c).run(stream);
333+
}
309334
}
310335

311336
} // end namespace matx

test/00_sparse/Matmul.cu

Lines changed: 45 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -20,14 +20,15 @@
2020
//
2121
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
2222
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23-
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24-
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25-
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26-
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27-
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28-
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29-
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30-
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
25+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
26+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
27+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
28+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
29+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
30+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31+
// POSSIBILITY OF SUCH DAMAGE.
3132
/////////////////////////////////////////////////////////////////////////////////
3233

3334
#include "assert.h"
@@ -135,7 +136,24 @@ TYPED_TEST(MatmulSparseTestsAll, MatmulCOO) {
135136
else {
136137
ASSERT_NEAR(O(i, j), E(i, j), this->thresh);
137138
}
139+
}
140+
}
141+
142+
// Allow transforming output.
143+
auto TO = make_tensor<TestType>({n, m});
144+
(transpose(TO) = matmul(S, B)).run(exec);
138145

146+
// Verify result.
147+
exec.sync();
148+
for (index_t i = 0; i < m; i++) {
149+
for (index_t j = 0; j < n; j++) {
150+
if constexpr (is_complex_v<TestType>) {
151+
ASSERT_NEAR(TO(j, i).real(), E(i, j).real(), this->thresh);
152+
ASSERT_NEAR(TO(j, i).imag(), E(i,j ).imag(), this->thresh);
153+
}
154+
else {
155+
ASSERT_NEAR(TO(j, i), E(i, j), this->thresh);
156+
}
139157
}
140158
}
141159

@@ -177,7 +195,6 @@ TYPED_TEST(MatmulSparseTestsAll, MatmulCSR) {
177195
else {
178196
ASSERT_NEAR(O(i, j), E(i, j), this->thresh);
179197
}
180-
181198
}
182199
}
183200

@@ -219,7 +236,26 @@ TYPED_TEST(MatmulSparseTestsAll, MatmulCSC) {
219236
else {
220237
ASSERT_NEAR(O(i, j), E(i, j), this->thresh);
221238
}
239+
}
240+
}
241+
242+
// Allow dense computations (pre-/post-matmul).
243+
TestType C3 = static_cast<TestType>(3);
244+
TestType C5 = static_cast<TestType>(5);
245+
(B = (B - C3)).run(exec);
246+
(O = matmul(S, B + C3) + C5).run(exec);
222247

248+
// Verify result.
249+
exec.sync();
250+
for (index_t i = 0; i < m; i++) {
251+
for (index_t j = 0; j < n; j++) {
252+
if constexpr (is_complex_v<TestType>) {
253+
ASSERT_NEAR((O(i, j) - C5).real(), E(i, j).real(), this->thresh);
254+
ASSERT_NEAR((O(i, j) - C5).imag(), E(i,j ).imag(), this->thresh);
255+
}
256+
else {
257+
ASSERT_NEAR(O(i, j) - C5, E(i, j), this->thresh);
258+
}
223259
}
224260
}
225261

0 commit comments

Comments (0)