Skip to content

Commit e2fd9d7

Browse files
authored
squash all commits (#881)
1 parent 4fed562 commit e2fd9d7

File tree

3 files changed

+59
-39
lines changed

3 files changed

+59
-39
lines changed

include/matx/transforms/convert/dense2sparse_cusparse.h

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
// Redistribution and use in source and binary forms, with or without
88
// modification, are permitted provided that the following conditions are met:
99
//
10-
// 1. Redistributions of source code must retain the above copyright notice, this
11-
// list of conditions and the following disclaimer.
10+
// 1. Redistributions of source code must retain the above copyright notice,
11+
// this list of conditions and the following disclaimer.
1212
//
1313
// 2. Redistributions in binary form must reproduce the above copyright notice,
1414
// this list of conditions and the following disclaimer in the documentation
@@ -247,43 +247,56 @@ using dense2sparse_cache_t =
247247
std::unordered_map<Dense2SparseParams_t, std::any,
248248
Dense2SparseParamsKeyHash, Dense2SparseParamsKeyEq>;
249249

250+
template <typename Op>
251+
__MATX_INLINE__ auto getD2SSupportedTensor(const Op &in, cudaStream_t stream) {
252+
const auto func = [&]() {
253+
if constexpr (is_tensor_view_v<Op>)
254+
return in.Stride(Op::Rank() - 1) == 1;
255+
return true;
256+
};
257+
return GetSupportedTensor(in, func, MATX_ASYNC_DEVICE_MEMORY, stream);
258+
}
259+
250260
} // end namespace detail
251261

252262
template <typename OutputTensorType, typename InputTensorType>
253-
void dense2sparse_impl(OutputTensorType &o, const InputTensorType &a,
263+
void dense2sparse_impl(OutputTensorType &o, const InputTensorType &A,
254264
const cudaExecutor &exec) {
255265
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
256266
const auto stream = exec.getStream();
257267

258-
using TA = typename InputTensorType::value_type;
259-
using TO = typename OutputTensorType::value_type;
268+
// Transform into supported form.
269+
auto a = getD2SSupportedTensor(A, stream);
270+
if (!is_matx_transform_op<InputTensorType>() && !a.isSameView(A)) {
271+
(a = A).run(stream);
272+
}
273+
using atype = decltype(a);
274+
using otype = OutputTensorType;
275+
276+
using TA = typename atype::value_type;
277+
using TO = typename otype::value_type;
260278

261-
static constexpr int RANKA = InputTensorType::Rank();
262-
static constexpr int RANKO = OutputTensorType::Rank();
279+
static constexpr int RANKA = atype::Rank();
280+
static constexpr int RANKO = otype::Rank();
263281

264282
// Restrictions.
265283
static_assert(RANKA == RANKO, "tensors must have same rank");
266-
static_assert(std::is_same_v<TA, TO>,
267-
"tensors must have the same data type");
284+
static_assert(std::is_same_v<TA, TO>, "tensors must have the same data type");
268285
static_assert(std::is_same_v<TO, int8_t> ||
269-
std::is_same_v<TO, matx::matxFp16> ||
270-
std::is_same_v<TO, matx::matxBf16> ||
271-
std::is_same_v<TO, float> ||
272-
std::is_same_v<TO, double> ||
273-
std::is_same_v<TO, cuda::std::complex<float>> ||
274-
std::is_same_v<TO, cuda::std::complex<double>>,
286+
std::is_same_v<TO, matx::matxFp16> ||
287+
std::is_same_v<TO, matx::matxBf16> ||
288+
std::is_same_v<TO, float> || std::is_same_v<TO, double> ||
289+
std::is_same_v<TO, cuda::std::complex<float>> ||
290+
std::is_same_v<TO, cuda::std::complex<double>>,
275291
"unsupported data type");
276292
MATX_ASSERT(a.Stride(RANKA - 1) == 1, matxInvalidParameter);
277293

278294
// Get parameters required by these tensors (for caching).
279295
auto params =
280-
detail::Dense2SparseHandle_t<OutputTensorType,
281-
InputTensorType>::GetConvParams(o, a,
282-
stream);
296+
detail::Dense2SparseHandle_t<otype, atype>::GetConvParams(o, a, stream);
283297

284298
// Lookup and cache.
285-
using cache_val_type =
286-
detail::Dense2SparseHandle_t<OutputTensorType, InputTensorType>;
299+
using cache_val_type = detail::Dense2SparseHandle_t<otype, atype>;
287300
detail::GetCache().LookupAndExec<detail::dense2sparse_cache_t>(
288301
detail::GetCacheIdFromType<detail::dense2sparse_cache_t>(), params,
289302
[&]() { return std::make_shared<cache_val_type>(o, a, stream); },

include/matx/transforms/convert/sparse2dense_cusparse.h

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
// Redistribution and use in source and binary forms, with or without
88
// modification, are permitted provided that the following conditions are met:
99
//
10-
// 1. Redistributions of source code must retain the above copyright notice, this
11-
// list of conditions and the following disclaimer.
10+
// 1. Redistributions of source code must retain the above copyright notice,
11+
// this list of conditions and the following disclaimer.
1212
//
1313
// 2. Redistributions in binary form must reproduce the above copyright notice,
1414
// this list of conditions and the following disclaimer in the documentation
@@ -215,32 +215,33 @@ void sparse2dense_impl(OutputTensorType &o, const InputTensorType &a,
215215
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
216216
const auto stream = exec.getStream();
217217

218-
using TA = typename InputTensorType::value_type;
219-
using TO = typename OutputTensorType::value_type;
218+
using atype = InputTensorType;
219+
using otype = OutputTensorType;
220220

221-
static constexpr int RANKA = InputTensorType::Rank();
222-
static constexpr int RANKO = OutputTensorType::Rank();
221+
using TA = typename atype::value_type;
222+
using TO = typename otype::value_type;
223+
224+
static constexpr int RANKA = atype::Rank();
225+
static constexpr int RANKO = otype::Rank();
223226

224227
// Restrictions.
225228
static_assert(RANKA == RANKO, "tensors must have same rank");
226-
static_assert(std::is_same_v<TA, TO>,
227-
"tensors must have the same data type");
229+
static_assert(std::is_same_v<TA, TO>, "tensors must have the same data type");
228230
static_assert(std::is_same_v<TO, int8_t> ||
229-
std::is_same_v<TO, matx::matxFp16> ||
230-
std::is_same_v<TO, matx::matxBf16> ||
231-
std::is_same_v<TO, float> ||
232-
std::is_same_v<TO, double> ||
233-
std::is_same_v<TO, cuda::std::complex<float>> ||
234-
std::is_same_v<TO, cuda::std::complex<double>>,
231+
std::is_same_v<TO, matx::matxFp16> ||
232+
std::is_same_v<TO, matx::matxBf16> ||
233+
std::is_same_v<TO, float> || std::is_same_v<TO, double> ||
234+
std::is_same_v<TO, cuda::std::complex<float>> ||
235+
std::is_same_v<TO, cuda::std::complex<double>>,
235236
"unsupported data type");
236237
MATX_ASSERT(o.Stride(RANKO - 1) == 1, matxInvalidParameter);
237238

238239
// Get parameters required by these tensors (for caching).
239240
auto params =
240-
detail::Sparse2DenseHandle_t<OutputTensorType, InputTensorType>::GetConvParams(o, a, stream);
241+
detail::Sparse2DenseHandle_t<otype, atype>::GetConvParams(o, a, stream);
241242

242243
// Lookup and cache.
243-
using cache_val_type = detail::Sparse2DenseHandle_t<OutputTensorType, InputTensorType>;
244+
using cache_val_type = detail::Sparse2DenseHandle_t<otype, atype>;
244245
detail::GetCache().LookupAndExec<detail::sparse2dense_cache_t>(
245246
detail::GetCacheIdFromType<detail::sparse2dense_cache_t>(), params,
246247
[&]() { return std::make_shared<cache_val_type>(o, a, stream); },

test/00_sparse/Convert.cu

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,15 +196,21 @@ TYPED_TEST(ConvertSparseTestsAll, ConvertCSC) {
196196
}
197197
}
198198

199-
// Allow dense computations (post-convert).
199+
// Allow dense computations (pre-convert).
200200
TestType C3 = static_cast<TestType>(3);
201-
(O = sparse2dense(S) + C3).run(exec);
201+
(S = dense2sparse(D + C3)).run(exec);
202+
203+
ASSERT_EQ(S.Nse(), 100); // fully dense now
204+
205+
// Allow dense computations (post-convert).
206+
TestType C5 = static_cast<TestType>(5);
207+
(O = sparse2dense(S) + C5).run(exec);
202208

203209
// Verify result.
204210
exec.sync();
205211
for (index_t i = 0; i < m; i++) {
206212
for (index_t j = 0; j < n; j++) {
207-
ASSERT_EQ(O(i, j) - C3, D(i, j));
213+
ASSERT_EQ(O(i, j) - C5, D(i, j) + C3);
208214
}
209215
}
210216

0 commit comments

Comments (0)