// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
// ///////////////////////////////////////////////////////////////////////////////

#pragma once
@@ -97,10 +98,9 @@ class MatMulCUSPARSEHandle_t {
9798 std::is_same_v<TC, cuda::std::complex <double >>) {
9899 salpha_ = {alpha, 0 };
99100 sbeta_ = {beta, 0 };
100- }
101- else if constexpr (std::is_same_v<TC, float > ||
102- std::is_same_v<TC, double >) {
103- salpha_ = alpha;;
101+ } else if constexpr (std::is_same_v<TC, float > ||
102+ std::is_same_v<TC, double >) {
103+ salpha_ = alpha;
104104 sbeta_ = beta;
105105 } else {
106106 MATX_THROW (matxNotSupported, " SpMM currently only supports uniform FP" );
@@ -150,9 +150,9 @@ class MatMulCUSPARSEHandle_t {
150150 // Allocate a workspace for SpMM.
151151 const cusparseSpMMAlg_t algo = CUSPARSE_SPMM_ALG_DEFAULT;
152152 const cudaDataType comptp = dtc; // TODO: support separate comp type?!
153- ret = cusparseSpMM_bufferSize (handle_, params_.opA , params_.opB ,
154- &salpha_, matA_, matB_, &sbeta_,
155- matC_, comptp, algo, &workspaceSize_);
153+ ret = cusparseSpMM_bufferSize (handle_, params_.opA , params_.opB , &salpha_,
154+ matA_, matB_, &sbeta_, matC_, comptp, algo ,
155+ &workspaceSize_);
156156 MATX_ASSERT (ret == CUSPARSE_STATUS_SUCCESS, matxMatMulError);
157157 if (workspaceSize_) {
158158 matxAlloc ((void **)&workspace_, workspaceSize_, MATX_DEVICE_MEMORY);
@@ -203,8 +203,8 @@ class MatMulCUSPARSEHandle_t {
203203 const cusparseSpMMAlg_t algo = CUSPARSE_SPMM_ALG_DEFAULT;
204204 const cudaDataType comptp = MatXTypeToCudaType<TC>(); // TODO: see above
205205 [[maybe_unused]] cusparseStatus_t ret =
206- cusparseSpMM (handle_, params_.opA , params_.opB , &salpha_, matA_,
207- matB_, &sbeta_, matC_, comptp, algo, workspace_);
206+ cusparseSpMM (handle_, params_.opA , params_.opB , &salpha_, matA_, matB_,
207+ &sbeta_, matC_, comptp, algo, workspace_);
208208 MATX_ASSERT (ret == CUSPARSE_STATUS_SUCCESS, matxMatMulError);
209209 }
210210
@@ -255,49 +255,69 @@ using gemm_cusparse_cache_t =
255255 std::unordered_map<MatMulCUSPARSEParams_t, std::any,
256256 MatMulCUSPARSEParamsKeyHash, MatMulCUSPARSEParamsKeyEq>;
257257
// Returns a tensor for operator `in` that cuSPARSE can consume directly.
// A tensor view qualifies as-is only when its innermost dimension has unit
// stride; any other operator (or a strided view) is given replacement storage
// requested with MATX_ASYNC_DEVICE_MEMORY on the provided stream via
// GetSupportedTensor. Note this does not copy data; callers copy when needed.
template <typename Op>
__MATX_INLINE__ auto getCuSparseSupportedTensor(const Op &in,
                                                cudaStream_t stream) {
  // Predicate: is the operator already laid out acceptably?
  const auto is_directly_usable = [&]() {
    if constexpr (is_tensor_view_v<Op>) {
      return in.Stride(Op::Rank() - 1) == 1;
    } else {
      return true;
    }
  };
  return GetSupportedTensor(in, is_directly_usable, MATX_ASYNC_DEVICE_MEMORY,
                            stream);
}
268+
258269} // end namespace detail
259270
// Computes the sparse-times-dense matrix product C = alpha * A * B + beta * C
// using cuSPARSE SpMM, where `a` is the sparse operand and `B`/`C` are dense
// rank-2 tensors. Dense operands whose innermost dimension is not unit-stride
// are first materialized into cuSPARSE-compatible temporaries, and the result
// is copied back into `C` when a temporary stood in for it. SpMM handles are
// cached and looked up by the parameters of the participating tensors.
template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB>
void sparse_matmul_impl(TensorTypeC &C, const TensorTypeA &a,
                        const TensorTypeB &B, const cudaExecutor &exec,
                        float alpha = 1.0, float beta = 0.0) {
  MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
  const auto stream = exec.getStream();

  // Transform B and C into supported (unit-inner-stride) form. B's data is
  // copied into its temporary here; NOTE(review): a transform-op B skips this
  // copy, presumably because the transform produces the data itself — confirm.
  auto b = getCuSparseSupportedTensor(B, stream);
  auto c = getCuSparseSupportedTensor(C, stream);
  if (!is_matx_transform_op<TensorTypeB>() && !b.isSameView(B)) {
    (b = B).run(stream);
  }

  using atype = TensorTypeA;
  using btype = decltype(b);
  using ctype = decltype(c);
  using TA = typename atype::value_type;
  using TB = typename btype::value_type;
  using TC = typename ctype::value_type;

  static constexpr int RANKA = atype::Rank();
  static constexpr int RANKB = btype::Rank();
  static constexpr int RANKC = ctype::Rank();

  // Restrictions: rank-2 operands with one uniform, supported data type.
  static_assert(RANKA == 2 && RANKB == 2 && RANKC == 2,
                "tensors must have rank-2");
  static_assert(std::is_same_v<TC, TA> && std::is_same_v<TC, TB>,
                "tensors must have the same data type");
  // TODO: allow MIXED-PRECISION computation!
  static_assert(std::is_same_v<TC, float> || std::is_same_v<TC, double> ||
                    std::is_same_v<TC, cuda::std::complex<float>> ||
                    std::is_same_v<TC, cuda::std::complex<double>>,
                "unsupported data type");
  // Shape conformance: (m x k) * (k x n) -> (m x n); dense operands must be
  // unit-stride in their innermost dimension for cuSPARSE.
  MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 2) &&
                  c.Size(RANKC - 1) == b.Size(RANKB - 1) &&
                  c.Size(RANKC - 2) == a.Size(RANKA - 2),
              matxInvalidSize);
  MATX_ASSERT(b.Stride(RANKB - 1) == 1 && c.Stride(RANKC - 1) == 1,
              matxInvalidParameter);

  // Get parameters required by these tensors (for caching).
  auto params =
      detail::MatMulCUSPARSEHandle_t<ctype, atype, btype>::GetGemmParams(
          c, a, b, stream, alpha, beta);

  // Lookup and cache.
  using cache_val_type = detail::MatMulCUSPARSEHandle_t<ctype, atype, btype>;
  detail::GetCache().LookupAndExec<detail::gemm_cusparse_cache_t>(
      detail::GetCacheIdFromType<detail::gemm_cusparse_cache_t>(), params,
      [&]() {
        // NOTE(review): factory args mirror GetGemmParams above — confirm
        // against the handle constructor's signature upstream.
        return std::make_shared<cache_val_type>(c, a, b, stream, alpha, beta);
      },
      [&](std::shared_ptr<cache_val_type> cache_type) {
        cache_type->Exec(c, a, b);
      });

  // Copy the transformed output back into the caller's tensor if a temporary
  // was used for C.
  if (!c.isSameView(C)) {
    (C = c).run(stream);
  }
}
310335
311336} // end namespace matx
0 commit comments