18 | 18 | #include "fbgemm_gpu/split_embeddings_utils.cuh"
19 | 19 | #include "fbgemm_gpu/utils/bitonic_sort.cuh"
20 | 20 | #include "fbgemm_gpu/utils/dispatch_macros.h"
| 21 | +#include "fbgemm_gpu/utils/kernel_launcher.cuh"
21 | 22 | #include "fbgemm_gpu/utils/tensor_accessor_builder.h"
22 | 23 | #include "fbgemm_gpu/utils/tensor_utils.h"
23 | 24 | #include "fbgemm_gpu/utils/vec4.cuh"
@@ -148,16 +149,16 @@ Tensor masked_index_impl(
148 | 149 |       is_index_put ? "masked_index_put" : "masked_index_select",
149 | 150 |       [&] {
150 | 151 |         using index_t = scalar_t;
151 | | -        masked_index_kernel<value_t, index_t, is_index_put>
152 | | -            <<<grid_size,
153 | | -               dim3(tx, kMaxThreads / tx),
154 | | -               0,
155 | | -               at::cuda::getCurrentCUDAStream()>>>(
156 | | -                MAKE_PTA_WITH_NAME(func_name, self, value_t, 2, 64),
157 | | -                MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
158 | | -                MAKE_PTA_WITH_NAME(func_name, values, value_t, 2, 64),
159 | | -                MAKE_PTA_WITH_NAME(func_name, count, int32_t, 1, 32));
160 | | -        C10_CUDA_KERNEL_LAUNCH_CHECK();
| 152 | +        FBGEMM_LAUNCH_KERNEL(
| 153 | +            (masked_index_kernel<value_t, index_t, is_index_put>),
| 154 | +            grid_size,
| 155 | +            dim3(tx, kMaxThreads / tx),
| 156 | +            0,
| 157 | +            at::cuda::getCurrentCUDAStream(),
| 158 | +            PTA_B(self, value_t, 2, 64),
| 159 | +            PTA_B(indices, index_t, 1, 32),
| 160 | +            PTA_B(values, value_t, 2, 64),
| 161 | +            PTA_B(count, int32_t, 1, 32));
161 | 162 |       } // lambda for FBGEMM_DISPATCH_INTEGRAL_TYPES
162 | 163 |   );
163 | 164 | } // lambda for FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE
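The same mechanical transformation repeats in every hunk below: the triple-chevron launch and its trailing C10_CUDA_KERNEL_LAUNCH_CHECK() collapse into a single FBGEMM_LAUNCH_KERNEL call (from the newly included kernel_launcher.cuh), and MAKE_PTA_WITH_NAME drops its explicit func_name argument in favor of PTA_B. A condensed before/after sketch, using a hypothetical example_kernel and tensor self that are not code from this diff:

    // Before: raw launch followed by a manual error check.
    //
    //   example_kernel<value_t, index_t>
    //       <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
    //           MAKE_PTA_WITH_NAME(func_name, self, value_t, 2, 64));
    //   C10_CUDA_KERNEL_LAUNCH_CHECK();

    // After: the launch configuration, stream, and kernel arguments all
    // travel through one macro call, which also performs the post-launch
    // error check.
    FBGEMM_LAUNCH_KERNEL(
        (example_kernel<value_t, index_t>),
        grid,
        block,
        0,
        at::cuda::getCurrentCUDAStream(),
        PTA_B(self, value_t, 2, 64));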
@@ -462,21 +463,22 @@ ssd_cache_populate_actions_cuda(
462 | 463 |
463 | 464 | template <typename index_t>
464 | 465 | __global__ __launch_bounds__(kMaxThreads) void ssd_generate_row_addrs_kernel(
465 | | -    at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> ssd_row_addrs,
466 | | -    at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
| 466 | +    pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
| 467 | +        ssd_row_addrs,
| 468 | +    pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
467 | 469 |         post_bwd_evicted_indices,
468 | | -    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
| 470 | +    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
469 | 471 |         lxu_cache_locations,
470 | | -    const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
| 472 | +    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
471 | 473 |         assigned_cache_slots,
472 | | -    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
| 474 | +    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
473 | 475 |         linear_index_inverse_indices,
474 | 476 |     // TODO: Use int64_t here
475 | | -    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
| 477 | +    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
476 | 478 |         unique_indices_count_cumsum,
477 | | -    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
| 479 | +    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
478 | 480 |         cache_set_inverse_indices,
479 | | -    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
| 481 | +    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
480 | 482 |         cache_set_sorted_unique_indices,
481 | 483 |     const uint64_t lxu_cache_weights_addr,
482 | 484 |     const uint64_t inserted_ssd_weights_addr,
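Device-side, the kernel parameter types move from at:: to pta:: accessors, which pair with the PTA_B builders at the launch sites below. A minimal end-to-end sketch with a hypothetical copy_kernel (the kernel name, tensors, and launch shape are illustrative assumptions; the accessor indexing semantics match at::PackedTensorAccessor32):

    __global__ void copy_kernel(
        pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> dst,
        const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
            src) {
      const auto i = blockIdx.x * blockDim.x + threadIdx.x;
      // size() and operator[] behave exactly as with the at:: accessors.
      if (i < src.size(0)) {
        dst[i] = src[i];
      }
    }

    // Host side, following the convention used throughout this diff:
    //   FBGEMM_LAUNCH_KERNEL(
    //       (copy_kernel),
    //       div_round_up(src.numel(), kMaxThreads),
    //       kMaxThreads,
    //       0,
    //       at::cuda::getCurrentCUDAStream(),
    //       PTA_B(dst, int64_t, 1, 32),
    //       PTA_B(src, int64_t, 1, 32));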
@@ -560,50 +562,42 @@ std::tuple<Tensor, Tensor> ssd_generate_row_addrs_cuda(
560 | 562 |       "ssd_generate_row_addrs",
561 | 563 |       [&] {
562 | 564 |         using index_t = scalar_t;
563 | | -        ssd_generate_row_addrs_kernel<<<
| 565 | +        FBGEMM_LAUNCH_KERNEL(
| 566 | +            (ssd_generate_row_addrs_kernel<index_t>),
564 | 567 |             div_round_up(lxu_cache_locations.numel(), kNumWarps),
565 | 568 |             dim3(kWarpSize, kNumWarps),
566 | 569 |             0,
567 | | -            at::cuda::getCurrentCUDAStream()>>>(
568 | | -            ssd_row_addrs
569 | | -                .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
570 | | -            post_bwd_evicted_indices
571 | | -                .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
572 | | -            lxu_cache_locations
573 | | -                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
574 | | -            assigned_cache_slots
575 | | -                .packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
576 | | -            linear_index_inverse_indices
577 | | -                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
578 | | -            unique_indices_count_cumsum
579 | | -                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
580 | | -            cache_set_inverse_indices
581 | | -                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
582 | | -            cache_set_sorted_unique_indices
583 | | -                .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
| 570 | +            at::cuda::getCurrentCUDAStream(),
| 571 | +            PTA_B(ssd_row_addrs, int64_t, 1, 32),
| 572 | +            PTA_B(post_bwd_evicted_indices, int64_t, 1, 32),
| 573 | +            PTA_B(lxu_cache_locations, int32_t, 1, 32),
| 574 | +            PTA_B(assigned_cache_slots, index_t, 1, 32),
| 575 | +            PTA_B(linear_index_inverse_indices, int32_t, 1, 32),
| 576 | +            PTA_B(unique_indices_count_cumsum, int32_t, 1, 32),
| 577 | +            PTA_B(cache_set_inverse_indices, int32_t, 1, 32),
| 578 | +            PTA_B(cache_set_sorted_unique_indices, int64_t, 1, 32),
584 | 579 |             lxu_cache_weights_addr,
585 | 580 |             reinterpret_cast<uint64_t>(inserted_ssd_weights.data_ptr()),
586 | 581 |             unique_indices_length.data_ptr<int32_t>(),
587 | 582 |             cache_row_bytes);
588 | | -        C10_CUDA_KERNEL_LAUNCH_CHECK();
589 | 583 |       } // lambda
590 | 584 |   );
591 | 585 |
592 | 586 |   return {ssd_row_addrs, post_bwd_evicted_indices};
593 | 587 | }
594 | 588 |
595 | 589 | __global__ __launch_bounds__(kMaxThreads) void ssd_update_row_addrs_kernel(
596 | | -    at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
| 590 | +    pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
597 | 591 |         ssd_row_addrs_curr,
598 | | -    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
| 592 | +    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
599 | 593 |         ssd_curr_next_map,
600 | | -    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
| 594 | +    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
601 | 595 |         lxu_cache_locations_curr,
602 | | -    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
| 596 | +    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
603 | 597 |         linear_index_inverse_indices_curr,
604 | | -    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
| 598 | +    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
605 | 599 |         unique_indices_count_cumsum_curr,
606 | | -    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
| 600 | +    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
607 | 601 |         cache_set_inverse_indices_curr,
608 | 602 |     const uint64_t lxu_cache_weights_addr,
609 | 603 |     const uint64_t inserted_ssd_weights_addr_next,
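Note that the removed launch above built its accessors inline with .packed_accessor32<...>() rather than MAKE_PTA_WITH_NAME; PTA_B now replaces both spellings at the call sites. A sketch of the contrast, assuming a 1-D int64 tensor t (exactly what accessor object PTA_B hands the launcher is an internal detail, so treat this as an assumption):

    // Old spelling: a bare at:: accessor, carrying no tensor name that
    // diagnostics could report.
    auto raw = t.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>();

    // New spelling, written directly in the launch argument list; the
    // builder captures the tensor together with its expected element
    // type and rank:
    //   PTA_B(t, int64_t, 1, 32)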
@@ -679,26 +673,22 @@ void ssd_update_row_addrs_cuda(
679 | 673 |       lxu_cache_weights.size(1) * lxu_cache_weights.element_size();
680 | 674 |   constexpr auto kNumWarps = kMaxThreads / kWarpSize;
681 | 675 |
682 | | -  ssd_update_row_addrs_kernel<<<
| 676 | +  FBGEMM_LAUNCH_KERNEL(
| 677 | +      (ssd_update_row_addrs_kernel),
683 | 678 |       div_round_up(ssd_row_addrs_curr.numel(), kNumWarps),
684 | 679 |       dim3(kWarpSize, kNumWarps),
685 | 680 |       0,
686 | | -      at::cuda::getCurrentCUDAStream()>>>(
687 | | -      ssd_row_addrs_curr.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
688 | | -      ssd_curr_next_map.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
689 | | -      lxu_cache_locations_curr
690 | | -          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
691 | | -      linear_index_inverse_indices_curr
692 | | -          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
693 | | -      unique_indices_count_cumsum_curr
694 | | -          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
695 | | -      cache_set_inverse_indices_curr
696 | | -          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
| 681 | +      at::cuda::getCurrentCUDAStream(),
| 682 | +      PTA_B(ssd_row_addrs_curr, int64_t, 1, 32),
| 683 | +      PTA_B(ssd_curr_next_map, int32_t, 1, 32),
| 684 | +      PTA_B(lxu_cache_locations_curr, int32_t, 1, 32),
| 685 | +      PTA_B(linear_index_inverse_indices_curr, int32_t, 1, 32),
| 686 | +      PTA_B(unique_indices_count_cumsum_curr, int32_t, 1, 32),
| 687 | +      PTA_B(cache_set_inverse_indices_curr, int32_t, 1, 32),
697 | 688 |       lxu_cache_weights_addr,
698 | 689 |       inserted_ssd_weights_addr_next,
699 | 690 |       unique_indices_length_curr.data_ptr<int32_t>(),
700 | 691 |       cache_row_bytes);
701 | | -  C10_CUDA_KERNEL_LAUNCH_CHECK();
702 | 692 | }
703 | 693 |
704 | 694 | template <typename index_t, typename offset_t>
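One consistency detail in the hunk above: ssd_update_row_addrs_kernel takes no template arguments, yet its name is still parenthesized. The parentheses are load-bearing only for templated kernels, whose template argument lists contain commas that the preprocessor would otherwise split into separate macro arguments; a stand-in macro (not the real launcher) shows the failure mode:

    // The preprocessor splits macro arguments on top-level commas.
    #define ONE_ARG_DEMO(kernel) /* expects exactly one argument */
    // ONE_ARG_DEMO(my_kernel<int, float>)    // seen as two arguments
    // ONE_ARG_DEMO((my_kernel<int, float>))  // one parenthesized argument

Applying the parentheses uniformly keeps every launch site shaped the same.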
@@ -772,22 +762,20 @@ void compact_indices_cuda(
772 | 762 |   FBGEMM_DISPATCH_INTEGRAL_TYPES(
773 | 763 |       offsets.scalar_type(), "compact_indices", [&] {
774 | 764 |         using offset_t = scalar_t;
775 | | -        compact_indices_kernel<<<
| 765 | +        FBGEMM_LAUNCH_KERNEL(
| 766 | +            (compact_indices_kernel<index_t, offset_t>),
776 | 767 |             // Launch N + 1 threads because we need at least one thread
777 | 768 |             // to set compact_count
778 | 769 |             div_round_up(N + 1, kMaxThreads),
779 | 770 |             kMaxThreads,
780 | 771 |             0,
781 | | -            at::cuda::getCurrentCUDAStream()>>>(
782 | | -            MAKE_PTA_WITH_NAME(
783 | | -                func_name, compact_indices[i], index_t, 1, 32),
784 | | -            MAKE_PTA_WITH_NAME(
785 | | -                func_name, compact_count, int32_t, 1, 32),
786 | | -            MAKE_PTA_WITH_NAME(func_name, indices[i], index_t, 1, 32),
787 | | -            MAKE_PTA_WITH_NAME(func_name, offsets, offset_t, 1, 32),
788 | | -            MAKE_PTA_WITH_NAME(func_name, masks, offset_t, 1, 32),
789 | | -            MAKE_PTA_WITH_NAME(func_name, count, int32_t, 1, 32));
790 | | -        C10_CUDA_KERNEL_LAUNCH_CHECK();
| 772 | +            at::cuda::getCurrentCUDAStream(),
| 773 | +            PTA_B(compact_indices[i], index_t, 1, 32),
| 774 | +            PTA_B(compact_count, int32_t, 1, 32),
| 775 | +            PTA_B(indices[i], index_t, 1, 32),
| 776 | +            PTA_B(offsets, offset_t, 1, 32),
| 777 | +            PTA_B(masks, offset_t, 1, 32),
| 778 | +            PTA_B(count, int32_t, 1, 32));
791 | 779 |       });
792 | 780 |   } // lambda
793 | 781 |   );
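Taken together, the hunks delete every manual C10_CUDA_KERNEL_LAUNCH_CHECK(). A rough sketch of what a launcher of this shape folds away, based only on the pattern removed above; the real FBGEMM_LAUNCH_KERNEL in kernel_launcher.cuh may additionally capture source location, validate the launch configuration, and materialize the pta:: accessors, so treat this as an outline, not its definition:

    #include <utility>

    // Launch a kernel through one uniform entry point, then surface any
    // launch-time error immediately instead of at the next sync point.
    template <typename Kernel, typename... Args>
    void launch_sketch(
        Kernel kernel,
        const dim3 grid,
        const dim3 block,
        const size_t smem,
        const cudaStream_t stream,
        Args&&... args) {
      kernel<<<grid, block, smem, stream>>>(std::forward<Args>(args)...);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

Passing the kernel as an ordinary function argument is also why the parenthesized (kernel) spelling at the call sites expands cleanly.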