set CUB_VERSION to 200001 for USE_ROCM (pytorch#140861)
Summary:
Currently, CUB_VERSION is defined as 0 when USE_ROCM is set. CUB_VERSION is used to determine whether advanced CUB APIs can be used in some implementations.
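
A gate of this kind, sketched from the bfloat16 comment in cub_definitions.cuh below (macro name and threshold are illustrative approximations, not copied verbatim): with CUB_VERSION pinned to 0 it always evaluates false on ROCm, so fallback paths are compiled; with 200001 (i.e. CUB/hipCUB 2.0.x) the newer API paths are enabled.

// Illustrative CUB_VERSION gate: false when CUB_VERSION == 0, true once it is 200001.
#if CUB_VERSION >= 101300   // __nv_bfloat16 radix-sort support landed in cub 1.13
#define CUB_SUPPORTS_NV_BFLOAT16() true
#else
#define CUB_SUPPORTS_NV_BFLOAT16() false
#endif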

Test Plan:
`buck2 build --flagfile fbsource//arvr/mode/win/vs2022/cpp20/cuda12_5/dev --flagfile fbsource//arvr/mode/cuda/rtx30 fbsource//arvr/libraries/eye/apollo_visualizer:unit_test_apollo_hu_module_capability`

`buck2 build --flagfile fbcode//mode/amd-gpu fbcode//aiplatform/modelstore/checkpointing/pyper:tensor_save_load_utils`

Differential Revision: D63054638

Pull Request resolved: pytorch#140861
Approved by: https://github.com/eqy, https://github.com/zoranzhao, https://github.com/houseroad
wangzhenict authored and pytorchmergebot committed Dec 10, 2024
1 parent 2f1191f commit e83b0fa
Showing 4 changed files with 15 additions and 5 deletions.
10 changes: 10 additions & 0 deletions aten/src/ATen/cuda/cub.cuh
@@ -527,16 +527,26 @@ template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename V
inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub InclusiveSumByKey does not support more than INT_MAX elements");
#if !defined(USE_ROCM)
CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveSumByKey,
keys, input, output, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
#else
CUB_WRAPPER(cub::DeviceScan::InclusiveSumByKey,
keys, input, output, num_items, hipcub::Equality(), at::cuda::getCurrentCUDAStream());
#endif
}

template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename ScanOpT>
inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub InclusiveSumByKey does not support more than INT_MAX elements");
#if !defined(USE_ROCM)
CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveScanByKey,
keys, input, output, scan_op, num_items, at_cuda_detail::cub::Equality(), at::cuda::getCurrentCUDAStream());
#else
CUB_WRAPPER(cub::DeviceScan::InclusiveScanByKey,
keys, input, output, scan_op, num_items, hipcub::Equality(), at::cuda::getCurrentCUDAStream());
#endif
}

#endif
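
For context (not part of this diff): CUB device-wide algorithms follow a two-call convention that a wrapper such as CUB_WRAPPER typically hides. The first call, with a null temp-storage pointer, only reports how many scratch bytes are needed; the second call does the actual work. A minimal sketch using plain CUB and cudaMalloc for the scratch buffer (the helper name and allocation strategy are assumptions for illustration):

#include <cub/cub.cuh>

template <typename KeysIt, typename ValsIt, typename OutIt>
void inclusive_sum_by_key_sketch(KeysIt keys, ValsIt in, OutIt out,
                                 int num_items, cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // First call: d_temp == nullptr, so only temp_bytes is computed.
  cub::DeviceScan::InclusiveSumByKey(d_temp, temp_bytes, keys, in, out,
                                     num_items, cub::Equality(), stream);
  cudaMalloc(&d_temp, temp_bytes);
  // Second call: performs the segmented (by-key) inclusive sum.
  cub::DeviceScan::InclusiveSumByKey(d_temp, temp_bytes, keys, in, out,
                                     num_items, cub::Equality(), stream);
  cudaFree(d_temp);
}
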
2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/cub_definitions.cuh
@@ -7,7 +7,7 @@
#if !defined(USE_ROCM)
#include <cub/version.cuh>
#else
- #define CUB_VERSION 0
+ #define CUB_VERSION 200001
#endif

// cub sort support for __nv_bfloat16 is added to cub 1.13 in:
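
As I understand CUB's versioning scheme, CUB_VERSION is encoded as major * 100000 + minor * 100 + patch (the same scheme Thrust uses), so 200001 corresponds to CUB/hipCUB 2.0.1 and the "cub 1.13" mentioned in the comment above is 101300. A quick sanity check of that arithmetic:

static_assert(2 * 100000 + 0 * 100 + 1 == 200001, "200001 encodes 2.0.1");
static_assert(1 * 100000 + 13 * 100 + 0 == 101300, "101300 encodes 1.13.0");
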
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/Embedding.cu
@@ -317,7 +317,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
auto count_data = count.mutable_data_ptr<index_t>();
cuda::cub::inclusive_sum_by_key(
sorted_data,
- at_cuda_detail::cub::ConstantInputIterator<index_t>(1),
+ NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::ConstantInputIterator<index_t>(1),
count_data,
num_indices
);
@@ -329,7 +329,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
thrust::make_reverse_iterator(sorted_data + num_indices),
thrust::make_reverse_iterator(static_cast<const index_t*>(count_data) + num_indices),
thrust::make_reverse_iterator(count_data + num_indices),
- at_cuda_detail::cub::Max(),
+ NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max(),
num_indices
);
});
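
The NO_ROCM(...)ROCM_HIPCUB(...) tokens pick the CUB namespace per backend at the call site, so the same line compiles to at_cuda_detail::cub::Max() under CUDA and to ::hipcub::Max() under ROCm. A sketch of macro definitions consistent with that expansion (an assumption for illustration, not copied from the ATen headers):

#if !defined(USE_ROCM)
#define NO_ROCM(x) x             // keep the at_cuda_detail qualifier on CUDA
#define ROCM_HIPCUB(x) x         // keep ::cub on CUDA
#else
#define NO_ROCM(x)               // drop the qualifier on ROCm
#define ROCM_HIPCUB(x) ::hipcub  // swap ::cub for ::hipcub on ROCm
#endif
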
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/EmbeddingBag.cu
@@ -209,7 +209,7 @@ Tensor embedding_bag_backward_cuda_sum_avg(
auto count_data = count.mutable_data_ptr<index_t>();
cuda::cub::inclusive_sum_by_key(
sorted_data,
- at_cuda_detail::cub::ConstantInputIterator<index_t>(1),
+ NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::ConstantInputIterator<index_t>(1),
count_data,
num_indices
);
@@ -221,7 +221,7 @@ Tensor embedding_bag_backward_cuda_sum_avg(
thrust::make_reverse_iterator(sorted_data + num_indices),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::make_reverse_iterator(count_data + num_indices),
- at_cuda_detail::cub::Max(),
+ NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max(),
num_indices
);
});
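
For context on what the two wrapped calls compute here (an illustration, not code from this commit): the segmented inclusive sum of a constant 1 gives each element's running count within its run of equal sorted indices, and the reverse segmented Max-scan then broadcasts each run's total back to every element, i.e. the frequency of each index. A host-side sketch of the same computation:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> sorted = {0, 0, 1, 2, 2, 2};   // sorted indices (keys)
  std::vector<int> count(sorted.size());
  // inclusive_sum_by_key(sorted, ConstantInputIterator(1), count, n)
  for (size_t i = 0; i < sorted.size(); ++i)
    count[i] = (i > 0 && sorted[i] == sorted[i - 1]) ? count[i - 1] + 1 : 1;
  // count == {1, 2, 1, 1, 2, 3}
  // reverse inclusive_scan_by_key(sorted, count, count, Max(), n)
  for (size_t i = sorted.size(); i-- > 0;)
    if (i + 1 < sorted.size() && sorted[i] == sorted[i + 1])
      count[i] = std::max(count[i], count[i + 1]);
  // count == {2, 2, 1, 3, 3, 3}: each position holds its index's total frequency.
  for (int c : count) std::printf("%d ", c);
  return 0;
}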
