|
7 | 7 | // Redistribution and use in source and binary forms, with or without |
8 | 8 | // modification, are permitted provided that the following conditions are met: |
9 | 9 | // |
10 | | -// 1. Redistributions of source code must retain the above copyright notice, this |
11 | | -// list of conditions and the following disclaimer. |
| 10 | +// 1. Redistributions of source code must retain the above copyright notice, |
| 11 | +// this list of conditions and the following disclaimer. |
12 | 12 | // |
13 | 13 | // 2. Redistributions in binary form must reproduce the above copyright notice, |
14 | 14 | // this list of conditions and the following disclaimer in the documentation |
@@ -247,43 +247,56 @@ using dense2sparse_cache_t = |
247 | 247 | std::unordered_map<Dense2SparseParams_t, std::any, |
248 | 248 | Dense2SparseParamsKeyHash, Dense2SparseParamsKeyEq>; |
249 | 249 |
|
| 250 | +template <typename Op> |
| 251 | +__MATX_INLINE__ auto getD2SSupportedTensor(const Op &in, cudaStream_t stream) { |
| 252 | + const auto func = [&]() { |
| 253 | + if constexpr (is_tensor_view_v<Op>) |
| 254 | + return in.Stride(Op::Rank() - 1) == 1; |
| 255 | + return true; |
| 256 | + }; |
| 257 | + return GetSupportedTensor(in, func, MATX_ASYNC_DEVICE_MEMORY, stream); |
| 258 | +} |
| 259 | + |
250 | 260 | } // end namespace detail |
251 | 261 |
|
252 | 262 | template <typename OutputTensorType, typename InputTensorType> |
253 | | -void dense2sparse_impl(OutputTensorType &o, const InputTensorType &a, |
| 263 | +void dense2sparse_impl(OutputTensorType &o, const InputTensorType &A, |
254 | 264 | const cudaExecutor &exec) { |
255 | 265 | MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) |
256 | 266 | const auto stream = exec.getStream(); |
257 | 267 |
|
258 | | - using TA = typename InputTensorType::value_type; |
259 | | - using TO = typename OutputTensorType::value_type; |
| 268 | + // Transform into supported form. |
| 269 | + auto a = getD2SSupportedTensor(A, stream); |
| 270 | + if (!is_matx_transform_op<InputTensorType>() && !a.isSameView(A)) { |
| 271 | + (a = A).run(stream); |
| 272 | + } |
| 273 | + using atype = decltype(a); |
| 274 | + using otype = OutputTensorType; |
| 275 | + |
| 276 | + using TA = typename atype::value_type; |
| 277 | + using TO = typename otype::value_type; |
260 | 278 |
|
261 | | - static constexpr int RANKA = InputTensorType::Rank(); |
262 | | - static constexpr int RANKO = OutputTensorType::Rank(); |
| 279 | + static constexpr int RANKA = atype::Rank(); |
| 280 | + static constexpr int RANKO = otype::Rank(); |
263 | 281 |
|
264 | 282 | // Restrictions. |
265 | 283 | static_assert(RANKA == RANKO, "tensors must have same rank"); |
266 | | - static_assert(std::is_same_v<TA, TO>, |
267 | | - "tensors must have the same data type"); |
| 284 | + static_assert(std::is_same_v<TA, TO>, "tensors must have the same data type"); |
268 | 285 | static_assert(std::is_same_v<TO, int8_t> || |
269 | | - std::is_same_v<TO, matx::matxFp16> || |
270 | | - std::is_same_v<TO, matx::matxBf16> || |
271 | | - std::is_same_v<TO, float> || |
272 | | - std::is_same_v<TO, double> || |
273 | | - std::is_same_v<TO, cuda::std::complex<float>> || |
274 | | - std::is_same_v<TO, cuda::std::complex<double>>, |
| 286 | + std::is_same_v<TO, matx::matxFp16> || |
| 287 | + std::is_same_v<TO, matx::matxBf16> || |
| 288 | + std::is_same_v<TO, float> || std::is_same_v<TO, double> || |
| 289 | + std::is_same_v<TO, cuda::std::complex<float>> || |
| 290 | + std::is_same_v<TO, cuda::std::complex<double>>, |
275 | 291 | "unsupported data type"); |
276 | 292 | MATX_ASSERT(a.Stride(RANKA - 1) == 1, matxInvalidParameter); |
277 | 293 |
|
278 | 294 | // Get parameters required by these tensors (for caching). |
279 | 295 | auto params = |
280 | | - detail::Dense2SparseHandle_t<OutputTensorType, |
281 | | - InputTensorType>::GetConvParams(o, a, |
282 | | - stream); |
| 296 | + detail::Dense2SparseHandle_t<otype, atype>::GetConvParams(o, a, stream); |
283 | 297 |
|
284 | 298 | // Lookup and cache. |
285 | | - using cache_val_type = |
286 | | - detail::Dense2SparseHandle_t<OutputTensorType, InputTensorType>; |
| 299 | + using cache_val_type = detail::Dense2SparseHandle_t<otype, atype>; |
287 | 300 | detail::GetCache().LookupAndExec<detail::dense2sparse_cache_t>( |
288 | 301 | detail::GetCacheIdFromType<detail::dense2sparse_cache_t>(), params, |
289 | 302 | [&]() { return std::make_shared<cache_val_type>(o, a, stream); }, |
|
0 commit comments