
Commit cbc7dbb

q10 authored and facebook-github-bot committed
Migrate TBE SSD cache kernels to FBGEMM_LAUNCH_KERNEL (#4142)
Summary:
Pull Request resolved: #4142

- Migrate TBE SSD cache kernels to `FBGEMM_LAUNCH_KERNEL`

Reviewed By: r-barnes

Differential Revision: D74440538

fbshipit-source-id: 578d81632adf965e1341e7cde0d3325ac02e481d
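The change is mechanical at every call site: a raw triple-chevron launch followed by a trailing `C10_CUDA_KERNEL_LAUNCH_CHECK()` becomes a single `FBGEMM_LAUNCH_KERNEL` invocation. A minimal sketch of the pattern, using a hypothetical `my_kernel` and placeholder `grid`/`block` configuration (the real call sites appear in the diff below):

// Before: manual launch, with a separate post-launch error check.
my_kernel<value_t, index_t>
    <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
        MAKE_PTA_WITH_NAME(func_name, self, value_t, 2, 64));
C10_CUDA_KERNEL_LAUNCH_CHECK();

// After: the macro receives the kernel, launch configuration, stream,
// and arguments in one place; the per-site launch check disappears,
// presumably because the launcher performs it internally.
FBGEMM_LAUNCH_KERNEL(
    (my_kernel<value_t, index_t>),
    grid,
    block,
    0,
    at::cuda::getCurrentCUDAStream(),
    PTA_B(self, value_t, 2, 64));

Note the parentheses around the kernel name: without them, the commas inside the template argument list would be parsed as macro argument separators.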
1 parent 3776e72 commit cbc7dbb

File tree

1 file changed: +55 −67 lines changed

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu

Lines changed: 55 additions & 67 deletions
@@ -18,6 +18,7 @@
 #include "fbgemm_gpu/split_embeddings_utils.cuh"
 #include "fbgemm_gpu/utils/bitonic_sort.cuh"
 #include "fbgemm_gpu/utils/dispatch_macros.h"
+#include "fbgemm_gpu/utils/kernel_launcher.cuh"
 #include "fbgemm_gpu/utils/tensor_accessor_builder.h"
 #include "fbgemm_gpu/utils/tensor_utils.h"
 #include "fbgemm_gpu/utils/vec4.cuh"
@@ -148,16 +149,16 @@ Tensor masked_index_impl(
       is_index_put ? "masked_index_put" : "masked_index_select",
       [&] {
         using index_t = scalar_t;
-        masked_index_kernel<value_t, index_t, is_index_put>
-            <<<grid_size,
-               dim3(tx, kMaxThreads / tx),
-               0,
-               at::cuda::getCurrentCUDAStream()>>>(
-                MAKE_PTA_WITH_NAME(func_name, self, value_t, 2, 64),
-                MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
-                MAKE_PTA_WITH_NAME(func_name, values, value_t, 2, 64),
-                MAKE_PTA_WITH_NAME(func_name, count, int32_t, 1, 32));
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
+        FBGEMM_LAUNCH_KERNEL(
+            (masked_index_kernel<value_t, index_t, is_index_put>),
+            grid_size,
+            dim3(tx, kMaxThreads / tx),
+            0,
+            at::cuda::getCurrentCUDAStream(),
+            PTA_B(self, value_t, 2, 64),
+            PTA_B(indices, index_t, 1, 32),
+            PTA_B(values, value_t, 2, 64),
+            PTA_B(count, int32_t, 1, 32));
       } // lambda for FBGEMM_DISPATCH_INTEGRAL_TYPES
   );
 } // lambda for FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE
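Alongside the launch macro, accessor construction switches from `MAKE_PTA_WITH_NAME` (and, at other sites below, raw `.packed_accessor32<...>()` calls) to the shorter `PTA_B` builder from `tensor_accessor_builder.h`. A side-by-side sketch with a hypothetical tensor `self`; the positional arguments are unchanged, and only the explicit `func_name` string drops out, presumably because `FBGEMM_LAUNCH_KERNEL` can attach the kernel context to diagnostics itself:

// Old: the debug name is threaded through by hand.
MAKE_PTA_WITH_NAME(func_name, self, value_t, 2, 64)

// New: same (tensor, dtype, ndim, index-bit-width) arguments, no name.
PTA_B(self, value_t, 2, 64)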
@@ -462,21 +463,22 @@ ssd_cache_populate_actions_cuda(
 
 template <typename index_t>
 __global__ __launch_bounds__(kMaxThreads) void ssd_generate_row_addrs_kernel(
-    at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> ssd_row_addrs,
-    at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+        ssd_row_addrs,
+    pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         post_bwd_evicted_indices,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         lxu_cache_locations,
-    const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         assigned_cache_slots,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         linear_index_inverse_indices,
     // TODO: Use int64_t here
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         unique_indices_count_cumsum,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         cache_set_inverse_indices,
-    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         cache_set_sorted_unique_indices,
     const uint64_t lxu_cache_weights_addr,
     const uint64_t inserted_ssd_weights_addr,
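The kernel signatures change in lockstep: parameters move from `at::PackedTensorAccessor32` to `pta::PackedTensorAccessor32`, which appears to be the accessor type that `PTA_B` builds, while the `at::RestrictPtrTraits` tag stays in the ATen namespace. A hypothetical two-parameter signature showing the shape of the change (parameter names are illustrative, not from the diff):

// Sketch only, under the assumption stated above.
__global__ void example_kernel(
    pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> out,
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> in);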
@@ -560,50 +562,42 @@ std::tuple<Tensor, Tensor> ssd_generate_row_addrs_cuda(
       "ssd_generate_row_addrs",
       [&] {
         using index_t = scalar_t;
-        ssd_generate_row_addrs_kernel<<<
+        FBGEMM_LAUNCH_KERNEL(
+            (ssd_generate_row_addrs_kernel<index_t>),
             div_round_up(lxu_cache_locations.numel(), kNumWarps),
             dim3(kWarpSize, kNumWarps),
             0,
-            at::cuda::getCurrentCUDAStream()>>>(
-            ssd_row_addrs
-                .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-            post_bwd_evicted_indices
-                .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-            lxu_cache_locations
-                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            assigned_cache_slots
-                .packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
-            linear_index_inverse_indices
-                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            unique_indices_count_cumsum
-                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            cache_set_inverse_indices
-                .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            cache_set_sorted_unique_indices
-                .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+            at::cuda::getCurrentCUDAStream(),
+            PTA_B(ssd_row_addrs, int64_t, 1, 32),
+            PTA_B(post_bwd_evicted_indices, int64_t, 1, 32),
+            PTA_B(lxu_cache_locations, int32_t, 1, 32),
+            PTA_B(assigned_cache_slots, index_t, 1, 32),
+            PTA_B(linear_index_inverse_indices, int32_t, 1, 32),
+            PTA_B(unique_indices_count_cumsum, int32_t, 1, 32),
+            PTA_B(cache_set_inverse_indices, int32_t, 1, 32),
+            PTA_B(cache_set_sorted_unique_indices, int64_t, 1, 32),
             lxu_cache_weights_addr,
             reinterpret_cast<uint64_t>(inserted_ssd_weights.data_ptr()),
             unique_indices_length.data_ptr<int32_t>(),
             cache_row_bytes);
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
       } // lambda
   );
 
   return {ssd_row_addrs, post_bwd_evicted_indices};
 }
 
 __global__ __launch_bounds__(kMaxThreads) void ssd_update_row_addrs_kernel(
-    at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         ssd_row_addrs_curr,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         ssd_curr_next_map,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         lxu_cache_locations_curr,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         linear_index_inverse_indices_curr,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         unique_indices_count_cumsum_curr,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         cache_set_inverse_indices_curr,
     const uint64_t lxu_cache_weights_addr,
     const uint64_t inserted_ssd_weights_addr_next,
@@ -679,26 +673,22 @@ void ssd_update_row_addrs_cuda(
       lxu_cache_weights.size(1) * lxu_cache_weights.element_size();
   constexpr auto kNumWarps = kMaxThreads / kWarpSize;
 
-  ssd_update_row_addrs_kernel<<<
+  FBGEMM_LAUNCH_KERNEL(
+      (ssd_update_row_addrs_kernel),
       div_round_up(ssd_row_addrs_curr.numel(), kNumWarps),
       dim3(kWarpSize, kNumWarps),
       0,
-      at::cuda::getCurrentCUDAStream()>>>(
-      ssd_row_addrs_curr.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-      ssd_curr_next_map.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-      lxu_cache_locations_curr
-          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-      linear_index_inverse_indices_curr
-          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-      unique_indices_count_cumsum_curr
-          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-      cache_set_inverse_indices_curr
-          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+      at::cuda::getCurrentCUDAStream(),
+      PTA_B(ssd_row_addrs_curr, int64_t, 1, 32),
+      PTA_B(ssd_curr_next_map, int32_t, 1, 32),
+      PTA_B(lxu_cache_locations_curr, int32_t, 1, 32),
+      PTA_B(linear_index_inverse_indices_curr, int32_t, 1, 32),
+      PTA_B(unique_indices_count_cumsum_curr, int32_t, 1, 32),
+      PTA_B(cache_set_inverse_indices_curr, int32_t, 1, 32),
       lxu_cache_weights_addr,
       inserted_ssd_weights_addr_next,
       unique_indices_length_curr.data_ptr<int32_t>(),
       cache_row_bytes);
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
 template <typename index_t, typename offset_t>
@@ -772,22 +762,20 @@ void compact_indices_cuda(
   FBGEMM_DISPATCH_INTEGRAL_TYPES(
       offsets.scalar_type(), "compact_indices", [&] {
         using offset_t = scalar_t;
-        compact_indices_kernel<<<
+        FBGEMM_LAUNCH_KERNEL(
+            (compact_indices_kernel<index_t, offset_t>),
             // Launch N + 1 threads because we need at least one thread
             // to set compact_count
             div_round_up(N + 1, kMaxThreads),
             kMaxThreads,
             0,
-            at::cuda::getCurrentCUDAStream()>>>(
-            MAKE_PTA_WITH_NAME(
-                func_name, compact_indices[i], index_t, 1, 32),
-            MAKE_PTA_WITH_NAME(
-                func_name, compact_count, int32_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, indices[i], index_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, offsets, offset_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, masks, offset_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, count, int32_t, 1, 32));
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
+            at::cuda::getCurrentCUDAStream(),
+            PTA_B(compact_indices[i], index_t, 1, 32),
+            PTA_B(compact_count, int32_t, 1, 32),
+            PTA_B(indices[i], index_t, 1, 32),
+            PTA_B(offsets, offset_t, 1, 32),
+            PTA_B(masks, offset_t, 1, 32),
+            PTA_B(count, int32_t, 1, 32));
       });
     } // lambda
 );
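For intuition, a launcher of this shape can be approximated by a small variadic macro. This is a hedged sketch under the assumption that `FBGEMM_LAUNCH_KERNEL` wraps the launch together with the error check; the real macro in `kernel_launcher.cuh` is more elaborate.

// Hypothetical stand-in, not the FBGEMM implementation.
#define LAUNCH_KERNEL_SKETCH(KERNEL, GRID, BLOCK, SMEM, STREAM, ...) \
  do {                                                               \
    KERNEL<<<(GRID), (BLOCK), (SMEM), (STREAM)>>>(__VA_ARGS__);      \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                  \
  } while (0)

This accounts for the net line count of the diff: each migrated site loses its dedicated C10_CUDA_KERNEL_LAUNCH_CHECK() line because the check now lives in one place.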
