@@ -31,6 +31,17 @@ __inline__ constexpr int ceil_of_ratio(int a, int b) {
   return (a + b - 1) / b;
 };
 
+template <typename T>
+__inline__ T* get_ptr(std::optional<at::Tensor> tensor) {
+  return reinterpret_cast<T*>(
+      tensor.has_value() ? tensor->data_ptr() : nullptr);
+};
+
+template <typename T>
+__inline__ __device__ T get_item(const T* ptr, const T& default_value) {
+  return ptr != nullptr ? *ptr : default_value;
+};
+
 #ifdef USE_ROCM
 __device__ __forceinline__ int atomic_add_relaxed(int* addr, int inc) {
   return __hip_atomic_fetch_add(
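For reference, a sketch of how the two new helpers compose; both usage lines are taken from later hunks in this patch and are not part of the diff at this point:

    // Host side: unwrap the optional tensor into a raw pointer, or nullptr if absent.
    .valid_token_count = get_ptr<IndexType>(valid_token_count),
    // Device side: dereference the pointer if non-null, otherwise fall back to the
    // default, so std::optional never reaches device code.
    const int num_valid_tokens = get_item(params.valid_token_count, num_total_tokens);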
@@ -71,24 +82,29 @@ __device__ __forceinline__ int load_aquire(int* addr) {
 
 template <class DataType, class IndexType, int NumExperts, int NumTokensPerTile>
 struct SharedStorage {
-  DataType scores[NumTokensPerTile * NumExperts];
+  DataType routing_scores[NumTokensPerTile * NumExperts];
   IndexType expert_indices[NumTokensPerTile * NumExperts];
-  IndexType expert_count_cumsums[NumExperts];
+  IndexType token_count_cumsums[NumExperts];
 };
 
 template <class DataType, class IndexType>
 struct Params {
-  const DataType* scores;
-  int stride_t_;
-  int stride_e_;
-  int num_tokens;
-  int num_tokens_per_cta;
-  IndexType* counts;
-  IndexType* expert_indices;
-  IndexType* token_indices;
+  // Inputs
+  const DataType* routing_scores;
+  const int stride_t;
+  const int stride_e;
+  const IndexType* valid_token_count;
+  const int num_tokens;
+  const int num_tokens_per_cta;
+
+  // Buffer
+  IndexType* buffered_expert_indices;
+  IndexType* buffered_token_indices;
+
+  // Outputs
+  IndexType* token_count_per_expert;
   IndexType* shuffled_expert_indices;
   IndexType* shuffled_token_indices;
-  IndexType* num_valid_tokens;
 };
 
 template <class DataType, class IndexType, int NumExperts, int NumTokensPerTile>
@@ -106,7 +122,7 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
 
   const int num_total_tokens = params.num_tokens;
   const int num_valid_tokens =
-      params.num_valid_tokens ? *params.num_valid_tokens : num_total_tokens;
+      get_item(params.valid_token_count, num_total_tokens);
 
   const int token_index_offset_start = bidx * params.num_tokens_per_cta;
   const int token_index_offset_end = std::min(
@@ -116,8 +132,8 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
     return;
   }
 
-  const int stride_t_ = params.stride_t_;
-  const int stride_e_ = params.stride_e_;
+  const int stride_t = params.stride_t;
+  const int stride_e = params.stride_e;
 
   for (int token_index_offset = token_index_offset_start;
        token_index_offset < token_index_offset_end;
@@ -129,8 +145,9 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
       int token_index = token_index_offset + i / NumExperts;
       int expert_index = i % NumExperts;
 
-      smem.scores[i] = token_index < num_valid_tokens
-          ? params.scores[token_index * stride_t_ + expert_index * stride_e_]
+      smem.routing_scores[i] = token_index < num_valid_tokens
+          ? params.routing_scores
+                [token_index * stride_t + expert_index * stride_e]
          : static_cast<DataType>(0.0f);
       smem.expert_indices[i] = expert_index;
     }
@@ -160,13 +177,14 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
           (tidx % kNumParallelReductionThreads) * 2;
       int rhs_smem_index = lhs_smem_index + num_reduced_threads;
 
-      auto lhs_score = smem.scores[lhs_smem_index];
-      auto rhs_score = smem.scores[rhs_smem_index];
+      auto lhs_score = smem.routing_scores[lhs_smem_index];
+      auto rhs_score = smem.routing_scores[rhs_smem_index];
       auto lhs_expert_index = smem.expert_indices[lhs_smem_index];
       auto rhs_expert_index = smem.expert_indices[rhs_smem_index];
 
       bool lhs_larger = lhs_score >= rhs_score;
-      smem.scores[lhs_smem_index] = lhs_larger ? lhs_score : rhs_score;
+      smem.routing_scores[lhs_smem_index] =
+          lhs_larger ? lhs_score : rhs_score;
       smem.expert_indices[lhs_smem_index] =
           lhs_larger ? lhs_expert_index : rhs_expert_index;
     }
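For context, a descriptive note (not part of the diff): this loop is a pairwise argmax reduction in shared memory. Each step compares the (score, expert index) pair at lhs_smem_index with its partner at rhs_smem_index and keeps the winner at lhs_smem_index, so after the reduction finishes, slot [local_token_index * NumExperts] holds the top-1 expert for that token, which is what the scatter step below reads.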
@@ -193,17 +211,17 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
       if (token_index < num_valid_tokens) {
         auto expert_index = smem.expert_indices[local_token_index * NumExperts];
         auto token_index_in_expert =
-            atomic_add_relaxed(&params.counts[expert_index], 1);
-        params.expert_indices[token_index] = expert_index;
-        params.token_indices[token_index] = token_index_in_expert;
+            atomic_add_relaxed(&params.token_count_per_expert[expert_index], 1);
+        params.buffered_expert_indices[token_index] = expert_index;
+        params.buffered_token_indices[token_index] = token_index_in_expert;
       }
     }
     __syncthreads();
   }
 
   if (tidx == 0) {
     int processed_tokens = 0;
-    int* processed_tokens_addr = &params.counts[NumExperts];
+    int* processed_tokens_addr = &params.token_count_per_expert[NumExperts];
 
     int inc = token_index_offset_end - token_index_offset_start;
     atomic_add_release(processed_tokens_addr, inc);
@@ -217,15 +235,15 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
   // 4. Scan
   static_assert(kNumThreads >= NumExperts, "");
   if (tidx < NumExperts) {
-    smem.expert_count_cumsums[tidx] = params.counts[tidx];
+    smem.token_count_cumsums[tidx] = params.token_count_per_expert[tidx];
   }
   __syncthreads();
 
   if (tidx == 0) {
     // TODO(shikaili): parallel.
 #pragma unroll
     for (int i = 1; i < NumExperts; ++i) {
-      smem.expert_count_cumsums[i] += smem.expert_count_cumsums[i - 1];
+      smem.token_count_cumsums[i] += smem.token_count_cumsums[i - 1];
     }
   }
   __syncthreads();
@@ -236,11 +254,10 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
        global_token_offset += kNumThreads) {
     int token_index = global_token_offset + tidx;
     if (token_index < num_valid_tokens) {
-      int expert_index = params.expert_indices[token_index];
-      int token_index_in_expert = params.token_indices[token_index];
+      int expert_index = params.buffered_expert_indices[token_index];
+      int token_index_in_expert = params.buffered_token_indices[token_index];
       int new_token_index =
-          (expert_index == 0 ? 0
-                             : smem.expert_count_cumsums[expert_index - 1]) +
+          (expert_index == 0 ? 0 : smem.token_count_cumsums[expert_index - 1]) +
           token_index_in_expert;
       params.shuffled_expert_indices[new_token_index] = expert_index;
       params.shuffled_token_indices[new_token_index] = token_index;
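For reference, a worked example of the scatter above with hypothetical counts (the numbers are illustrative, not from this patch):

    // With NumExperts = 4 and token_count_per_expert = {3, 1, 2, 4}, the inclusive
    // cumsum held in smem.token_count_cumsums is {3, 4, 6, 10}. A token routed to
    // expert_index = 2 with token_index_in_expert = 1 lands at
    //   new_token_index = token_count_cumsums[1] + 1 = 4 + 1 = 5,
    // i.e. the second slot of expert 2's contiguous block [4, 6).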
@@ -255,42 +272,37 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
 } // namespace
 
 std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
-    const at::Tensor& scores,
-    std::optional<at::Tensor> num_valid_tokens) {
-  TORCH_CHECK(scores.dtype() == torch::kBFloat16);
+    const at::Tensor& routing_scores,
+    std::optional<at::Tensor> valid_token_count) {
+  TORCH_CHECK(routing_scores.dtype() == torch::kBFloat16);
   using DataType = __nv_bfloat16;
   using IndexType = int32_t;
 
-  TORCH_CHECK(scores.dim() == 2);
-  const int num_tokens = scores.size(0);
-  const int num_experts = scores.size(1);
+  TORCH_CHECK(routing_scores.dim() == 2);
+  const int num_tokens = routing_scores.size(0);
+  const int num_experts = routing_scores.size(1);
   TORCH_CHECK(num_experts == 16 || num_experts == 128);
 
   auto allocate_index_tensor = [&](int size) {
     return at::empty(
-        {size}, at::TensorOptions().dtype(at::kInt).device(scores.device()));
+        {size},
+        at::TensorOptions().dtype(at::kInt).device(routing_scores.device()));
   };
-  at::Tensor counts = allocate_index_tensor(num_experts + 1);
-  at::Tensor expert_indices = allocate_index_tensor(num_tokens);
-  at::Tensor token_indices = allocate_index_tensor(num_tokens);
+  at::Tensor token_count_per_expert = allocate_index_tensor(num_experts + 1);
   at::Tensor shuffled_expert_indices = allocate_index_tensor(num_tokens);
   at::Tensor shuffled_token_indices = allocate_index_tensor(num_tokens);
+  at::Tensor buffered_expert_indices = allocate_index_tensor(num_tokens);
+  at::Tensor buffered_token_indices = allocate_index_tensor(num_tokens);
 
 #ifdef USE_ROCM
-  counts.zero_();
   // TODO(shikaili): hipMetsetAsync is more expensive than ATen set zero.
-  /*
-  hipMemsetAsync(
-      counts.data_ptr(),
-      0,
-      counts.numel() * counts.dtype().itemsize(),
-      at::cuda::getCurrentCUDAStream());
-  */
+  token_count_per_expert.zero_();
 #else
   cudaMemsetAsync(
-      counts.data_ptr(),
+      token_count_per_expert.data_ptr(),
       0,
-      counts.numel() * counts.dtype().itemsize(),
+      token_count_per_expert.numel() *
+          token_count_per_expert.dtype().itemsize(),
       at::cuda::getCurrentCUDAStream());
 #endif
 
@@ -323,26 +335,30 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
 
   const int num_tiles = ceil_of_ratio(num_tokens, kNumTokensPerTile);
   const int num_ctas = std::min(num_tiles, num_sms);
-
-  int num_tokens_per_cta = ceil_of_ratio(num_tokens, num_ctas);
   const int num_tiles_per_cta =
-      ceil_of_ratio(num_tokens_per_cta, kNumTokensPerTile);
-  num_tokens_per_cta = num_tiles_per_cta * kNumTokensPerTile;
+      ceil_of_ratio(ceil_of_ratio(num_tokens, num_ctas), kNumTokensPerTile);
+  const int num_tokens_per_cta = num_tiles_per_cta * kNumTokensPerTile;
 
   Params<DataType, IndexType> params = {
-      reinterpret_cast<DataType*>(scores.data_ptr()),
-      static_cast<int>(scores.stride(0)),
-      static_cast<int>(scores.stride(1)),
-      num_tokens,
-      num_tokens_per_cta,
-      reinterpret_cast<IndexType*>(counts.data_ptr()),
-      reinterpret_cast<IndexType*>(expert_indices.data_ptr()),
-      reinterpret_cast<IndexType*>(token_indices.data_ptr()),
-      reinterpret_cast<IndexType*>(shuffled_expert_indices.data_ptr()),
-      reinterpret_cast<IndexType*>(shuffled_token_indices.data_ptr()),
-      reinterpret_cast<IndexType*>(
-          num_valid_tokens.has_value() ? num_valid_tokens->data_ptr()
-                                       : nullptr)};
+      // Inputs
+      .routing_scores = reinterpret_cast<DataType*>(routing_scores.data_ptr()),
+      .stride_t = static_cast<int>(routing_scores.stride(0)),
+      .stride_e = static_cast<int>(routing_scores.stride(1)),
+      .valid_token_count = get_ptr<IndexType>(valid_token_count),
+      .num_tokens = num_tokens,
+      .num_tokens_per_cta = num_tokens_per_cta,
+      // Buffer
+      .buffered_expert_indices =
+          reinterpret_cast<IndexType*>(buffered_expert_indices.data_ptr()),
+      .buffered_token_indices =
+          reinterpret_cast<IndexType*>(buffered_token_indices.data_ptr()),
+      // Outputs
+      .token_count_per_expert =
+          reinterpret_cast<IndexType*>(token_count_per_expert.data_ptr()),
+      .shuffled_expert_indices =
+          reinterpret_cast<IndexType*>(shuffled_expert_indices.data_ptr()),
+      .shuffled_token_indices =
+          reinterpret_cast<IndexType*>(shuffled_token_indices.data_ptr())};
 
   dim3 grids(num_ctas);
   dim3 blocks(kNumThreads);
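For reference, the tiling arithmetic above with hypothetical values; kNumTokensPerTile = 16, num_sms = 8, and num_tokens = 1000 are assumptions for illustration only:

    // num_tiles          = ceil_of_ratio(1000, 16)                  = 63
    // num_ctas           = std::min(63, 8)                          = 8
    // num_tiles_per_cta  = ceil_of_ratio(ceil_of_ratio(1000, 8), 16)
    //                    = ceil_of_ratio(125, 16)                   = 8
    // num_tokens_per_cta = 8 * 16                                   = 128
    // so 8 CTAs x 128 tokens per CTA covers all 1000 tokens in whole tiles.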
@@ -360,7 +376,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
 #endif
 
   return std::make_tuple(
-      counts, shuffled_expert_indices, shuffled_token_indices);
+      token_count_per_expert, shuffled_expert_indices, shuffled_token_indices);
 }
 
 } // namespace fbgemm_gpu
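For reference, a hypothetical caller of the updated entry point; the tensor shape, the at::randn call, and passing std::nullopt are illustrative assumptions, not taken from this patch:

    // Build a bf16 [num_tokens, num_experts] routing-score tensor on the GPU.
    auto routing_scores =
        at::randn({4096, 16}, at::dtype(at::kBFloat16).device(at::kCUDA));
    // Run index shuffling without a valid-token-count tensor: all tokens are valid.
    auto [token_count_per_expert, shuffled_expert_indices, shuffled_token_indices] =
        index_shuffling_torch(routing_scores, /*valid_token_count=*/std::nullopt);
    // token_count_per_expert has num_experts + 1 entries; the extra slot [num_experts]
    // is the kernel's internal cross-CTA "processed tokens" counter.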