Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 33 additions & 20 deletions include/matx/transforms/convert/dense2sparse_cusparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
Expand Down Expand Up @@ -247,43 +247,56 @@ using dense2sparse_cache_t =
std::unordered_map<Dense2SparseParams_t, std::any,
Dense2SparseParamsKeyHash, Dense2SparseParamsKeyEq>;

template <typename Op>
__MATX_INLINE__ auto getD2SSupportedTensor(const Op &in, cudaStream_t stream) {
const auto func = [&]() {
if constexpr (is_tensor_view_v<Op>)
return in.Stride(Op::Rank() - 1) == 1;
return true;
};
return GetSupportedTensor(in, func, MATX_ASYNC_DEVICE_MEMORY, stream);
}

} // end namespace detail

template <typename OutputTensorType, typename InputTensorType>
void dense2sparse_impl(OutputTensorType &o, const InputTensorType &a,
void dense2sparse_impl(OutputTensorType &o, const InputTensorType &A,
const cudaExecutor &exec) {
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
const auto stream = exec.getStream();

using TA = typename InputTensorType::value_type;
using TO = typename OutputTensorType::value_type;
// Transform into supported form.
auto a = getD2SSupportedTensor(A, stream);
if (!is_matx_transform_op<InputTensorType>() && !a.isSameView(A)) {
(a = A).run(stream);
}
using atype = decltype(a);
using otype = OutputTensorType;

using TA = typename atype::value_type;
using TO = typename otype::value_type;

static constexpr int RANKA = InputTensorType::Rank();
static constexpr int RANKO = OutputTensorType::Rank();
static constexpr int RANKA = atype::Rank();
static constexpr int RANKO = otype::Rank();

// Restrictions.
static_assert(RANKA == RANKO, "tensors must have same rank");
static_assert(std::is_same_v<TA, TO>,
"tensors must have the same data type");
static_assert(std::is_same_v<TA, TO>, "tensors must have the same data type");
static_assert(std::is_same_v<TO, int8_t> ||
std::is_same_v<TO, matx::matxFp16> ||
std::is_same_v<TO, matx::matxBf16> ||
std::is_same_v<TO, float> ||
std::is_same_v<TO, double> ||
std::is_same_v<TO, cuda::std::complex<float>> ||
std::is_same_v<TO, cuda::std::complex<double>>,
std::is_same_v<TO, matx::matxFp16> ||
std::is_same_v<TO, matx::matxBf16> ||
std::is_same_v<TO, float> || std::is_same_v<TO, double> ||
std::is_same_v<TO, cuda::std::complex<float>> ||
std::is_same_v<TO, cuda::std::complex<double>>,
"unsupported data type");
MATX_ASSERT(a.Stride(RANKA - 1) == 1, matxInvalidParameter);

// Get parameters required by these tensors (for caching).
auto params =
detail::Dense2SparseHandle_t<OutputTensorType,
InputTensorType>::GetConvParams(o, a,
stream);
detail::Dense2SparseHandle_t<otype, atype>::GetConvParams(o, a, stream);

// Lookup and cache.
using cache_val_type =
detail::Dense2SparseHandle_t<OutputTensorType, InputTensorType>;
using cache_val_type = detail::Dense2SparseHandle_t<otype, atype>;
detail::GetCache().LookupAndExec<detail::dense2sparse_cache_t>(
detail::GetCacheIdFromType<detail::dense2sparse_cache_t>(), params,
[&]() { return std::make_shared<cache_val_type>(o, a, stream); },
Expand Down
33 changes: 17 additions & 16 deletions include/matx/transforms/convert/sparse2dense_cusparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
Expand Down Expand Up @@ -215,32 +215,33 @@ void sparse2dense_impl(OutputTensorType &o, const InputTensorType &a,
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
const auto stream = exec.getStream();

using TA = typename InputTensorType::value_type;
using TO = typename OutputTensorType::value_type;
using atype = InputTensorType;
using otype = OutputTensorType;

static constexpr int RANKA = InputTensorType::Rank();
static constexpr int RANKO = OutputTensorType::Rank();
using TA = typename atype::value_type;
using TO = typename otype::value_type;

static constexpr int RANKA = atype::Rank();
static constexpr int RANKO = otype::Rank();

// Restrictions.
static_assert(RANKA == RANKO, "tensors must have same rank");
static_assert(std::is_same_v<TA, TO>,
"tensors must have the same data type");
static_assert(std::is_same_v<TA, TO>, "tensors must have the same data type");
static_assert(std::is_same_v<TO, int8_t> ||
std::is_same_v<TO, matx::matxFp16> ||
std::is_same_v<TO, matx::matxBf16> ||
std::is_same_v<TO, float> ||
std::is_same_v<TO, double> ||
std::is_same_v<TO, cuda::std::complex<float>> ||
std::is_same_v<TO, cuda::std::complex<double>>,
std::is_same_v<TO, matx::matxFp16> ||
std::is_same_v<TO, matx::matxBf16> ||
std::is_same_v<TO, float> || std::is_same_v<TO, double> ||
std::is_same_v<TO, cuda::std::complex<float>> ||
std::is_same_v<TO, cuda::std::complex<double>>,
"unsupported data type");
MATX_ASSERT(o.Stride(RANKO - 1) == 1, matxInvalidParameter);

// Get parameters required by these tensors (for caching).
auto params =
detail::Sparse2DenseHandle_t<OutputTensorType, InputTensorType>::GetConvParams(o, a, stream);
detail::Sparse2DenseHandle_t<otype, atype>::GetConvParams(o, a, stream);

// Lookup and cache.
using cache_val_type = detail::Sparse2DenseHandle_t<OutputTensorType, InputTensorType>;
using cache_val_type = detail::Sparse2DenseHandle_t<otype, atype>;
detail::GetCache().LookupAndExec<detail::sparse2dense_cache_t>(
detail::GetCacheIdFromType<detail::sparse2dense_cache_t>(), params,
[&]() { return std::make_shared<cache_val_type>(o, a, stream); },
Expand Down
12 changes: 9 additions & 3 deletions test/00_sparse/Convert.cu
Original file line number Diff line number Diff line change
Expand Up @@ -196,15 +196,21 @@ TYPED_TEST(ConvertSparseTestsAll, ConvertCSC) {
}
}

// Allow dense computations (post-convert).
// Allow dense computations (pre-convert).
TestType C3 = static_cast<TestType>(3);
(O = sparse2dense(S) + C3).run(exec);
(S = dense2sparse(D + C3)).run(exec);

ASSERT_EQ(S.Nse(), 100); // fully dense now

// Allow dense computations (post-convert).
TestType C5 = static_cast<TestType>(5);
(O = sparse2dense(S) + C5).run(exec);

// Verify result.
exec.sync();
for (index_t i = 0; i < m; i++) {
for (index_t j = 0; j < n; j++) {
ASSERT_EQ(O(i, j) - C3, D(i, j));
ASSERT_EQ(O(i, j) - C5, D(i, j) + C3);
}
}

Expand Down