Skip to content

Commit 127848a

Browse files
q10 authored and facebook-github-bot committed
Migrate TBE cache kernels to FBGEMM_LAUNCH_KERNEL (#4127)
Summary:
X-link: facebookresearch/FBGEMM#1208
Pull Request resolved: #4127
- Migrate TBE cache kernels to `FBGEMM_LAUNCH_KERNEL`
Reviewed By: spcyppt
Differential Revision: D74272500
fbshipit-source-id: 98c71b6286d3d7aad565cb1ae51111fac37069a5
1 parent cd605e8 commit 127848a

File tree

5 files changed

+94
-134
lines changed

5 files changed

+94
-134
lines changed

fbgemm_gpu/src/split_embeddings_cache/common.cuh

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@
3434
#include "fbgemm_gpu/utils/cuda_prelude.cuh"
3535
#include "fbgemm_gpu/utils/find_qparams.cuh"
3636
#include "fbgemm_gpu/utils/fixed_divisor.cuh"
37+
#include "fbgemm_gpu/utils/kernel_launcher.cuh"
3738
#include "fbgemm_gpu/utils/stochastic_rounding.cuh"
3839
#include "fbgemm_gpu/utils/vec4.cuh"
3940
#include "fbgemm_gpu/utils/vec4acc.cuh"

fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu

Lines changed: 21 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -198,35 +198,27 @@ void lfu_cache_insert_cuda(
198198
->philox_cuda_state(4);
199199
}
200200

201-
#ifdef FBGEMM_GPU_MEMCHECK
202-
const char* func_name = "lfu_cache_insert_kernel";
203-
#endif
204-
205-
lfu_cache_insert_kernel<emb_t, cache_t>
206-
<<<std::min(
207-
div_round_up(N, kCacheMaxThreads / kWarpSize),
208-
get_max_thread_blocks_for_cache_kernels_()),
209-
dim3(kWarpSize, kCacheMaxThreads / kWarpSize),
210-
0,
211-
at::cuda::getCurrentCUDAStream()>>>(
212-
MAKE_PTA_WITH_NAME(func_name, weights, emb_t, 1, 64),
213-
MAKE_PTA_WITH_NAME(
214-
func_name, cache_hash_size_cumsum, int64_t, 1, 32),
215-
MAKE_PTA_WITH_NAME(
216-
func_name, cache_index_table_map, int32_t, 1, 64),
217-
MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
218-
MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
219-
(uint64_t*)sorted_cache_sets.data_ptr<int64_t>(),
220-
MAKE_PTA_WITH_NAME(
221-
func_name, cache_set_sorted_unique_indices, int64_t, 1, 32),
222-
unique_indices_length.data_ptr<int32_t>(),
223-
MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
224-
MAKE_PTA_WITH_NAME(
225-
func_name, lxu_cache_weights, cache_t, 2, 64),
226-
MAKE_PTA_WITH_NAME(func_name, lfu_state, int64_t, 1, 64),
227-
stochastic_rounding_,
228-
rng_engine_inputs);
229-
C10_CUDA_KERNEL_LAUNCH_CHECK();
201+
FBGEMM_LAUNCH_KERNEL(
202+
(lfu_cache_insert_kernel<emb_t, cache_t>),
203+
std::min(
204+
div_round_up(N, kCacheMaxThreads / kWarpSize),
205+
get_max_thread_blocks_for_cache_kernels_()),
206+
dim3(kWarpSize, kCacheMaxThreads / kWarpSize),
207+
0,
208+
at::cuda::getCurrentCUDAStream(),
209+
PTA_B(weights, emb_t, 1, 64),
210+
PTA_B(cache_hash_size_cumsum, int64_t, 1, 32),
211+
PTA_B(cache_index_table_map, int32_t, 1, 64),
212+
PTA_B(weights_offsets, int64_t, 1, 32),
213+
PTA_B(D_offsets, int32_t, 1, 32),
214+
(uint64_t*)sorted_cache_sets.data_ptr<int64_t>(),
215+
PTA_B(cache_set_sorted_unique_indices, int64_t, 1, 32),
216+
unique_indices_length.data_ptr<int32_t>(),
217+
PTA_B(lxu_cache_state, int64_t, 2, 32),
218+
PTA_B(lxu_cache_weights, cache_t, 2, 64),
219+
PTA_B(lfu_state, int64_t, 1, 64),
220+
stochastic_rounding_,
221+
rng_engine_inputs);
230222
}));
231223
}
232224

fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu

Lines changed: 24 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -228,39 +228,30 @@ void lru_cache_insert_cuda(
228228
? div_round_up(get_device_sm_cnt_(), ALL_TO_PREFETCH_SM_RATIO)
229229
: div_round_up(N, kMaxThreads / kWarpSize);
230230

231-
#ifdef FBGEMM_GPU_MEMCHECK
232-
const char* func_name = "lru_cache_insert_kernel";
233-
#endif
234-
lru_cache_insert_kernel<emb_t, cache_t>
235-
<<<grid_size,
236-
dim3(kWarpSize, kMaxThreads / kWarpSize),
237-
0,
238-
at::cuda::getCurrentCUDAStream()>>>(
239-
MAKE_PTA_WITH_NAME(func_name, weights, emb_t, 1, 64),
240-
MAKE_PTA_WITH_NAME(
241-
func_name, cache_hash_size_cumsum, int64_t, 1, 32),
242-
MAKE_PTA_WITH_NAME(
243-
func_name, cache_index_table_map, int32_t, 1, 64),
244-
MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
245-
MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
246-
MAKE_PTA_WITH_NAME(
247-
func_name, sorted_cache_sets, int32_t, 1, 32),
248-
MAKE_PTA_WITH_NAME(
249-
func_name, cache_set_sorted_unique_indices, int64_t, 1, 32),
250-
unique_indices_length.data_ptr<int32_t>(),
251-
MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
252-
MAKE_PTA_WITH_NAME(
253-
func_name, lxu_cache_weights, cache_t, 2, 64),
254-
time_stamp,
255-
MAKE_PTA_WITH_NAME(func_name, lru_state, int64_t, 2, 32),
256-
stochastic_rounding_,
257-
rng_engine_inputs,
258-
gather_cache_stats,
259-
MAKE_PTA_WITH_NAME(func_name, uvm_cache_stats, int32_t, 1, 32),
260-
lock_cache_line,
261-
MAKE_PTA_WITH_NAME(
262-
func_name, lxu_cache_locking_counter, int32_t, 2, 32));
263-
C10_CUDA_KERNEL_LAUNCH_CHECK();
231+
FBGEMM_LAUNCH_KERNEL(
232+
(lru_cache_insert_kernel<emb_t, cache_t>),
233+
grid_size,
234+
dim3(kWarpSize, kMaxThreads / kWarpSize),
235+
0,
236+
at::cuda::getCurrentCUDAStream(),
237+
PTA_B(weights, emb_t, 1, 64),
238+
PTA_B(cache_hash_size_cumsum, int64_t, 1, 32),
239+
PTA_B(cache_index_table_map, int32_t, 1, 64),
240+
PTA_B(weights_offsets, int64_t, 1, 32),
241+
PTA_B(D_offsets, int32_t, 1, 32),
242+
PTA_B(sorted_cache_sets, int32_t, 1, 32),
243+
PTA_B(cache_set_sorted_unique_indices, int64_t, 1, 32),
244+
unique_indices_length.data_ptr<int32_t>(),
245+
PTA_B(lxu_cache_state, int64_t, 2, 32),
246+
PTA_B(lxu_cache_weights, cache_t, 2, 64),
247+
time_stamp,
248+
PTA_B(lru_state, int64_t, 2, 32),
249+
stochastic_rounding_,
250+
rng_engine_inputs,
251+
gather_cache_stats,
252+
PTA_B(uvm_cache_stats, int32_t, 1, 32),
253+
lock_cache_line,
254+
PTA_B(lxu_cache_locking_counter, int32_t, 2, 32));
264255
}));
265256
}
266257

fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu

Lines changed: 35 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -126,24 +126,22 @@ DLL_PUBLIC void lxu_cache_flush_cuda(
126126
rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen)
127127
->philox_cuda_state(4);
128128
}
129-
#ifdef FBGEMM_GPU_MEMCHECK
130-
const char* func_name = "lxu_cache_flush_kernel";
131-
#endif
132-
lxu_cache_flush_kernel<emb_t, cache_t>
133-
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
134-
MAKE_PTA_WITH_NAME(func_name, uvm_weights, emb_t, 1, 64),
135-
MAKE_PTA_WITH_NAME(
136-
func_name, cache_hash_size_cumsum, int64_t, 1, 32),
137-
MAKE_PTA_WITH_NAME(
138-
func_name, cache_index_table_map, int32_t, 1, 64),
139-
MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
140-
MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
141-
MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
142-
MAKE_PTA_WITH_NAME(
143-
func_name, lxu_cache_weights, cache_t, 2, 64),
144-
stochastic_rounding_,
145-
rng_engine_inputs);
146-
C10_CUDA_KERNEL_LAUNCH_CHECK();
129+
130+
FBGEMM_LAUNCH_KERNEL(
131+
(lxu_cache_flush_kernel<emb_t, cache_t>),
132+
blocks,
133+
threads,
134+
0,
135+
at::cuda::getCurrentCUDAStream(),
136+
PTA_B(uvm_weights, emb_t, 1, 64),
137+
PTA_B(cache_hash_size_cumsum, int64_t, 1, 32),
138+
PTA_B(cache_index_table_map, int32_t, 1, 64),
139+
PTA_B(weights_offsets, int64_t, 1, 32),
140+
PTA_B(D_offsets, int32_t, 1, 32),
141+
PTA_B(lxu_cache_state, int64_t, 2, 32),
142+
PTA_B(lxu_cache_weights, cache_t, 2, 64),
143+
stochastic_rounding_,
144+
rng_engine_inputs);
147145
}));
148146
}
149147

@@ -211,34 +209,26 @@ void lxu_cache_locking_counter_decrement_cuda(
211209
div_round_up(N, kMaxThreads),
212210
get_max_thread_blocks_for_cache_kernels_()));
213211

214-
#ifdef FBGEMM_GPU_MEMCHECK
215-
const char* func_name = "lxu_cache_locations_count_kernel";
216-
#endif
217-
218-
lxu_cache_locations_count_kernel<<<
212+
FBGEMM_LAUNCH_KERNEL(
213+
lxu_cache_locations_count_kernel,
219214
blocks,
220215
kMaxThreads,
221216
0,
222-
at::cuda::getCurrentCUDAStream()>>>(
217+
at::cuda::getCurrentCUDAStream(),
223218
MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32),
224219
MAKE_PTA_WITH_NAME(func_name, count, int32_t, 2, 32),
225220
fd);
226-
C10_CUDA_KERNEL_LAUNCH_CHECK();
227-
228-
#ifdef FBGEMM_GPU_MEMCHECK
229-
const char* func_name2 = "lxu_cache_locking_counter_decrement_kernel";
230-
#endif
231221

232-
lxu_cache_locking_counter_decrement_kernel<<<
222+
FBGEMM_LAUNCH_KERNEL(
223+
lxu_cache_locking_counter_decrement_kernel,
233224
std::min(
234225
div_round_up(C, kMaxThreads / kWarpSize),
235226
get_max_thread_blocks_for_cache_kernels_()),
236227
dim3(kWarpSize, kMaxThreads / kWarpSize),
237228
0,
238-
at::cuda::getCurrentCUDAStream()>>>(
229+
at::cuda::getCurrentCUDAStream(),
239230
MAKE_PTA_WITH_NAME(func_name2, lxu_cache_locking_counter, int32_t, 2, 32),
240231
MAKE_PTA_WITH_NAME(func_name2, count, int32_t, 2, 32));
241-
C10_CUDA_KERNEL_LAUNCH_CHECK();
242232
}
243233

244234
namespace {
@@ -445,14 +435,12 @@ DLL_PUBLIC Tensor lxu_cache_lookup_cuda(
445435

446436
AT_DISPATCH_INDEX_TYPES(
447437
linear_cache_indices.scalar_type(), "lxu_cache_lookup_cuda", [&] {
448-
#ifdef FBGEMM_GPU_MEMCHECK
449-
const char* func_name = "lxu_cache_lookup_kernel";
450-
#endif
451-
lxu_cache_lookup_kernel<<<
438+
FBGEMM_LAUNCH_KERNEL(
439+
(lxu_cache_lookup_kernel<index_t>),
452440
blocks,
453441
threads,
454442
0,
455-
at::cuda::getCurrentCUDAStream()>>>(
443+
at::cuda::getCurrentCUDAStream(),
456444
MAKE_PTA_WITH_NAME(func_name, linear_cache_indices, index_t, 1, 32),
457445
MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
458446
invalid_index,
@@ -462,7 +450,6 @@ DLL_PUBLIC Tensor lxu_cache_lookup_cuda(
462450
num_uniq_cache_indices.has_value()
463451
? num_uniq_cache_indices.value().data_ptr<int32_t>()
464452
: nullptr);
465-
C10_CUDA_KERNEL_LAUNCH_CHECK();
466453
});
467454
return lxu_cache_locations;
468455
}
@@ -499,21 +486,18 @@ DLL_PUBLIC Tensor direct_mapped_lxu_cache_lookup_cuda(
499486
linear_cache_indices.scalar_type(),
500487
"direct_mapped_lxu_cache_lookup_cuda",
501488
[&] {
502-
#ifdef FBGEMM_GPU_MEMCHECK
503-
const char* func_name = "direct_mapped_lxu_cache_lookup_kernel";
504-
#endif
505-
direct_mapped_lxu_cache_lookup_kernel<<<
489+
FBGEMM_LAUNCH_KERNEL(
490+
(direct_mapped_lxu_cache_lookup_kernel<index_t>),
506491
blocks,
507492
kMaxThreads,
508493
0,
509-
at::cuda::getCurrentCUDAStream()>>>(
510-
MAKE_PTA_WITH_NAME(func_name, linear_cache_indices, index_t, 1, 32),
511-
MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
494+
at::cuda::getCurrentCUDAStream(),
495+
PTA_B(linear_cache_indices, index_t, 1, 32),
496+
PTA_B(lxu_cache_state, int64_t, 2, 32),
512497
invalid_index,
513-
MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32),
498+
PTA_B(lxu_cache_locations, int32_t, 1, 32),
514499
gather_cache_stats,
515-
MAKE_PTA_WITH_NAME(func_name, uvm_cache_stats_, int32_t, 1, 32));
516-
C10_CUDA_KERNEL_LAUNCH_CHECK();
500+
PTA_B(uvm_cache_stats_, int32_t, 1, 32));
517501
});
518502

519503
return lxu_cache_locations;
@@ -559,21 +543,17 @@ DLL_PUBLIC void lxu_cache_locations_update_cuda(
559543
div_round_up(N, kMaxThreads),
560544
get_max_thread_blocks_for_cache_kernels_()));
561545

562-
#ifdef FBGEMM_GPU_MEMCHECK
563-
const char* func_name = "lxu_cache_locations_update_kernel";
564-
#endif
565-
566-
lxu_cache_locations_update_kernel<<<
546+
FBGEMM_LAUNCH_KERNEL(
547+
lxu_cache_locations_update_kernel,
567548
blocks,
568549
kMaxThreads,
569550
0,
570-
at::cuda::getCurrentCUDAStream()>>>(
551+
at::cuda::getCurrentCUDAStream(),
571552
MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32),
572553
MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations_new, int32_t, 1, 32),
573554
num_uniq_cache_indices.has_value()
574555
? num_uniq_cache_indices.value().data_ptr<int32_t>()
575556
: nullptr);
576557

577-
C10_CUDA_KERNEL_LAUNCH_CHECK();
578558
return;
579559
}

fbgemm_gpu/src/split_embeddings_cache/reset_weight_momentum.cu

Lines changed: 13 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -35,8 +35,8 @@ __global__ __launch_bounds__(kMaxThreads) void get_cache_indices_kernel(
3535
linear_cache_indices) {
3636
const int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
3737

38-
const int32_t t_i = blockIdx.x / blocks_per_table;
39-
const int32_t threads_per_table = blocks_per_table * blockDim.x;
38+
const auto t_i = blockIdx.x / blocks_per_table;
39+
const auto threads_per_table = blocks_per_table * blockDim.x;
4040
const int32_t idx_table = index % threads_per_table;
4141
const int32_t logical_id = logical_table_ids[t_i];
4242
const int32_t buffer_id = buffer_ids[t_i];
@@ -112,7 +112,7 @@ __global__ __launch_bounds__(kMaxThreads) void reset_weight_momentum_kernel(
112112
lxu_cache_locations) {
113113
const int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
114114

115-
const int32_t t_i = blockIdx.x / blocks_per_table;
115+
const auto t_i = blockIdx.x / blocks_per_table;
116116
const int32_t buffer_id = buffer_ids[t_i];
117117
const int64_t num_indices =
118118
pruned_indices_offsets[buffer_id + 1] - pruned_indices_offsets[buffer_id];
@@ -126,7 +126,7 @@ __global__ __launch_bounds__(kMaxThreads) void reset_weight_momentum_kernel(
126126
const int32_t chunk4s_per_row = D / 4;
127127
const int64_t total_chunk4s_per_table = num_indices * chunk4s_per_row;
128128

129-
const int32_t threads_per_table = blocks_per_table * blockDim.x;
129+
const auto threads_per_table = blocks_per_table * blockDim.x;
130130
const int64_t chunk4s_per_thread =
131131
div_round_up(total_chunk4s_per_table, threads_per_table);
132132
const int32_t idx_table = index % threads_per_table;
@@ -249,23 +249,19 @@ DLL_PUBLIC void reset_weight_momentum_cuda(
249249
auto linear_cache_indices = at::zeros(
250250
{num_pruned_indices}, pruned_indices.options().dtype(at::kLong));
251251

252-
#ifdef FBGEMM_GPU_MEMCHECK
253-
const char* func_name = "get_cache_indices_kernel";
254-
#endif
255-
256-
get_cache_indices_kernel<<<
252+
FBGEMM_LAUNCH_KERNEL(
253+
get_cache_indices_kernel,
257254
num_pruned_tables * blocks_per_table,
258255
kMaxThreads,
259256
0,
260-
at::cuda::getCurrentCUDAStream()>>>(
257+
at::cuda::getCurrentCUDAStream(),
261258
blocks_per_table,
262-
MAKE_PTA_WITH_NAME(func_name, cache_hash_size_cumsum, int64_t, 1, 32),
263-
MAKE_PTA_WITH_NAME(func_name, pruned_indices, int64_t, 1, 32),
264-
MAKE_PTA_WITH_NAME(func_name, pruned_indices_offsets, int64_t, 1, 32),
265-
MAKE_PTA_WITH_NAME(func_name, logical_table_ids, int32_t, 1, 32),
266-
MAKE_PTA_WITH_NAME(func_name, buffer_ids, int32_t, 1, 32),
267-
MAKE_PTA_WITH_NAME(func_name, linear_cache_indices, int64_t, 1, 32));
268-
C10_CUDA_KERNEL_LAUNCH_CHECK();
259+
PTA_B(cache_hash_size_cumsum, int64_t, 1, 32),
260+
PTA_B(pruned_indices, int64_t, 1, 32),
261+
PTA_B(pruned_indices_offsets, int64_t, 1, 32),
262+
PTA_B(logical_table_ids, int32_t, 1, 32),
263+
PTA_B(buffer_ids, int32_t, 1, 32),
264+
PTA_B(linear_cache_indices, int64_t, 1, 32));
269265

270266
// Look up cache locations
271267
Tensor uvm_cache_stats =

0 commit comments

Comments
 (0)