Commit 52f07e7

levendlee authored and facebook-github-bot committed
Clean up IndexShuffling op. (#4155)
Summary:
Pull Request resolved: #4155
X-link: facebookresearch/FBGEMM#1235

Clean up IndexShuffling op. Better naming and more readable.

Reviewed By: jasonjk-park

Differential Revision: D75039953

fbshipit-source-id: 201662a7a13aec8b7ac64c1c7424ddee76f2e02a
1 parent 39313bb commit 52f07e7

File tree

2 files changed: +97 -78 lines changed

fbgemm_gpu/experimental/gen_ai/src/moe/index_shuffling.cpp

Lines changed: 14 additions & 11 deletions
@@ -14,24 +14,27 @@
 namespace fbgemm_gpu {
 
 std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
-    const at::Tensor& scores,
-    std::optional<at::Tensor> num_valid_tokens);
+    const at::Tensor& routing_scores,
+    std::optional<at::Tensor> valid_token_count);
 
 std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch_meta(
-    const at::Tensor& scores,
-    std::optional<at::Tensor> num_valid_tokens) {
-  int T = scores.size(0);
-  int E = scores.size(1);
-  at::Tensor counts = at::empty({E + 1}, scores.options().dtype(at::kInt));
-  at::Tensor expert_indices = at::empty({T}, scores.options().dtype(at::kInt));
-  at::Tensor token_indices = at::empty({T}, scores.options().dtype(at::kInt));
-  return {counts, expert_indices, token_indices};
+    const at::Tensor& routing_scores,
+    std::optional<at::Tensor> valid_token_count) {
+  int T = routing_scores.size(0);
+  int E = routing_scores.size(1);
+  at::Tensor token_counts_per_expert =
+      at::empty({E + 1}, routing_scores.options().dtype(at::kInt));
+  at::Tensor expert_indices =
+      at::empty({T}, routing_scores.options().dtype(at::kInt));
+  at::Tensor token_indices =
+      at::empty({T}, routing_scores.options().dtype(at::kInt));
+  return {token_counts_per_expert, expert_indices, token_indices};
 }
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.set_python_module("fbgemm_gpu.experimental.gen_ai.moe");
   m.def(
-      "index_shuffling(Tensor scores, Tensor? num_valid_tokens= None) -> (Tensor, Tensor, Tensor)");
+      "index_shuffling(Tensor routing_scores, Tensor? valid_token_count=None) -> (Tensor, Tensor, Tensor)");
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
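
Note (editorial, not part of the diff): once the fbgemm_gpu extension is loaded, the renamed schema can be exercised through the PyTorch dispatcher. Below is a minimal C++ caller sketch under that assumption; the wrapper name call_index_shuffling is hypothetical.

// Resolves fbgemm::index_shuffling by schema name and calls it unboxed.
// Assumes routing_scores is a 2-D [T, E] bf16 CUDA tensor and
// valid_token_count, if given, is a single-element int32 CUDA tensor.
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

std::tuple<at::Tensor, at::Tensor, at::Tensor> call_index_shuffling(
    const at::Tensor& routing_scores,
    std::optional<at::Tensor> valid_token_count = std::nullopt) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::index_shuffling", "")
          .typed<std::tuple<at::Tensor, at::Tensor, at::Tensor>(
              const at::Tensor&, std::optional<at::Tensor>)>();
  return op.call(routing_scores, valid_token_count);
}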

fbgemm_gpu/experimental/gen_ai/src/moe/index_shuffling.cu

Lines changed: 83 additions & 67 deletions
@@ -31,6 +31,17 @@ __inline__ constexpr int ceil_of_ratio(int a, int b) {
   return (a + b - 1) / b;
 };
 
+template <typename T>
+__inline__ T* get_ptr(std::optional<at::Tensor> tensor) {
+  return reinterpret_cast<T*>(
+      tensor.has_value() ? tensor->data_ptr() : nullptr);
+};
+
+template <typename T>
+__inline__ __device__ T get_item(const T* ptr, const T& default_value) {
+  return ptr != nullptr ? *ptr : default_value;
+};
+
 #ifdef USE_ROCM
 __device__ __forceinline__ int atomic_add_relaxed(int* addr, int inc) {
   return __hip_atomic_fetch_add(
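
Editorial note: these two helpers split the optional-tensor handling between host and device. A minimal sketch of how they pair up (illustrative only; the kernel and wrapper names here are hypothetical, and the real usage appears in index_shuffling_torch below):

// Device side: fall back to a default when the pointer is null.
__global__ void read_valid_token_count(
    const int32_t* valid_token_count, int32_t num_total_tokens, int32_t* out) {
  *out = get_item(valid_token_count, num_total_tokens);
}

// Host side: get_ptr yields nullptr when the optional tensor is absent,
// so the kernel transparently uses num_total_tokens instead.
void launch_read_valid_token_count(
    std::optional<at::Tensor> valid_token_count,
    int32_t num_total_tokens,
    at::Tensor out) {
  read_valid_token_count<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
      get_ptr<int32_t>(valid_token_count),
      num_total_tokens,
      reinterpret_cast<int32_t*>(out.data_ptr()));
}
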
@@ -71,24 +82,29 @@ __device__ __forceinline__ int load_aquire(int* addr) {
 
 template <class DataType, class IndexType, int NumExperts, int NumTokensPerTile>
 struct SharedStorage {
-  DataType scores[NumTokensPerTile * NumExperts];
+  DataType routing_scores[NumTokensPerTile * NumExperts];
   IndexType expert_indices[NumTokensPerTile * NumExperts];
-  IndexType expert_count_cumsums[NumExperts];
+  IndexType token_count_cumsums[NumExperts];
 };
 
 template <class DataType, class IndexType>
 struct Params {
-  const DataType* scores;
-  int stride_t_;
-  int stride_e_;
-  int num_tokens;
-  int num_tokens_per_cta;
-  IndexType* counts;
-  IndexType* expert_indices;
-  IndexType* token_indices;
+  // Inputs
+  const DataType* routing_scores;
+  const int stride_t;
+  const int stride_e;
+  const IndexType* valid_token_count;
+  const int num_tokens;
+  const int num_tokens_per_cta;
+
+  // Buffer
+  IndexType* buffered_expert_indices;
+  IndexType* buffered_token_indices;
+
+  // Outputs
+  IndexType* token_count_per_expert;
   IndexType* shuffled_expert_indices;
   IndexType* shuffled_token_indices;
-  IndexType* num_valid_tokens;
 };
 
 template <class DataType, class IndexType, int NumExperts, int NumTokensPerTile>
@@ -106,7 +122,7 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
 
   const int num_total_tokens = params.num_tokens;
   const int num_valid_tokens =
-      params.num_valid_tokens ? *params.num_valid_tokens : num_total_tokens;
+      get_item(params.valid_token_count, num_total_tokens);
 
   const int token_index_offset_start = bidx * params.num_tokens_per_cta;
   const int token_index_offset_end = std::min(
@@ -116,8 +132,8 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
     return;
   }
 
-  const int stride_t_ = params.stride_t_;
-  const int stride_e_ = params.stride_e_;
+  const int stride_t = params.stride_t;
+  const int stride_e = params.stride_e;
 
   for (int token_index_offset = token_index_offset_start;
        token_index_offset < token_index_offset_end;
@@ -129,8 +145,9 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
       int token_index = token_index_offset + i / NumExperts;
       int expert_index = i % NumExperts;
 
-      smem.scores[i] = token_index < num_valid_tokens
-          ? params.scores[token_index * stride_t_ + expert_index * stride_e_]
+      smem.routing_scores[i] = token_index < num_valid_tokens
+          ? params.routing_scores
+                [token_index * stride_t + expert_index * stride_e]
           : static_cast<DataType>(0.0f);
       smem.expert_indices[i] = expert_index;
     }
@@ -160,13 +177,14 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
           (tidx % kNumParallelReductionThreads) * 2;
       int rhs_smem_index = lhs_smem_index + num_reduced_threads;
 
-      auto lhs_score = smem.scores[lhs_smem_index];
-      auto rhs_score = smem.scores[rhs_smem_index];
+      auto lhs_score = smem.routing_scores[lhs_smem_index];
+      auto rhs_score = smem.routing_scores[rhs_smem_index];
       auto lhs_expert_index = smem.expert_indices[lhs_smem_index];
       auto rhs_expert_index = smem.expert_indices[rhs_smem_index];
 
       bool lhs_larger = lhs_score >= rhs_score;
-      smem.scores[lhs_smem_index] = lhs_larger ? lhs_score : rhs_score;
+      smem.routing_scores[lhs_smem_index] =
+          lhs_larger ? lhs_score : rhs_score;
       smem.expert_indices[lhs_smem_index] =
           lhs_larger ? lhs_expert_index : rhs_expert_index;
     }
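
Editorial note: the loop above is a pairwise tree reduction over shared memory that leaves, for each token, the top-scoring expert in slot local_token_index * NumExperts. A host-side reference of what it computes per token (hypothetical helper; ties here resolve to the lower expert index, whereas the kernel's tie order follows its reduction tree):

int reference_top_expert(const float* token_scores, int num_experts) {
  int best_expert = 0;
  float best_score = token_scores[0];
  for (int e = 1; e < num_experts; ++e) {
    if (token_scores[e] > best_score) {
      best_score = token_scores[e];
      best_expert = e;
    }
  }
  return best_expert;
}
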
@@ -193,17 +211,17 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
       if (token_index < num_valid_tokens) {
         auto expert_index = smem.expert_indices[local_token_index * NumExperts];
         auto token_index_in_expert =
-            atomic_add_relaxed(&params.counts[expert_index], 1);
-        params.expert_indices[token_index] = expert_index;
-        params.token_indices[token_index] = token_index_in_expert;
+            atomic_add_relaxed(&params.token_count_per_expert[expert_index], 1);
+        params.buffered_expert_indices[token_index] = expert_index;
+        params.buffered_token_indices[token_index] = token_index_in_expert;
       }
     }
     __syncthreads();
   }
 
   if (tidx == 0) {
     int processed_tokens = 0;
-    int* processed_tokens_addr = &params.counts[NumExperts];
+    int* processed_tokens_addr = &params.token_count_per_expert[NumExperts];
 
     int inc = token_index_offset_end - token_index_offset_start;
     atomic_add_release(processed_tokens_addr, inc);
@@ -217,15 +235,15 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
   // 4. Scan
   static_assert(kNumThreads >= NumExperts, "");
   if (tidx < NumExperts) {
-    smem.expert_count_cumsums[tidx] = params.counts[tidx];
+    smem.token_count_cumsums[tidx] = params.token_count_per_expert[tidx];
   }
   __syncthreads();
 
   if (tidx == 0) {
     // TODO(shikaili): parallel.
 #pragma unroll
     for (int i = 1; i < NumExperts; ++i) {
-      smem.expert_count_cumsums[i] += smem.expert_count_cumsums[i - 1];
+      smem.token_count_cumsums[i] += smem.token_count_cumsums[i - 1];
     }
   }
   __syncthreads();
@@ -236,11 +254,10 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
        global_token_offset += kNumThreads) {
     int token_index = global_token_offset + tidx;
     if (token_index < num_valid_tokens) {
-      int expert_index = params.expert_indices[token_index];
-      int token_index_in_expert = params.token_indices[token_index];
+      int expert_index = params.buffered_expert_indices[token_index];
+      int token_index_in_expert = params.buffered_token_indices[token_index];
       int new_token_index =
-          (expert_index == 0 ? 0
-                             : smem.expert_count_cumsums[expert_index - 1]) +
+          (expert_index == 0 ? 0 : smem.token_count_cumsums[expert_index - 1]) +
           token_index_in_expert;
       params.shuffled_expert_indices[new_token_index] = expert_index;
       params.shuffled_token_indices[new_token_index] = token_index;
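
Editorial note: new_token_index is the inclusive cumsum of the preceding experts' token counts plus the token's slot within its own expert. For intuition, with hypothetical counts (not from this commit) token_count_per_expert = {3, 1, 2, ...}, the scan gives token_count_cumsums = {3, 4, 6, ...}; a token routed to expert 2 with token_index_in_expert = 1 is written to new_token_index = token_count_cumsums[1] + 1 = 5, i.e. right after the tokens of experts 0 and 1.
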
@@ -255,42 +272,37 @@ __global__ void index_shuffling_kernel(Params<DataType, IndexType> params) {
 } // namespace
 
 std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
-    const at::Tensor& scores,
-    std::optional<at::Tensor> num_valid_tokens) {
-  TORCH_CHECK(scores.dtype() == torch::kBFloat16);
+    const at::Tensor& routing_scores,
+    std::optional<at::Tensor> valid_token_count) {
+  TORCH_CHECK(routing_scores.dtype() == torch::kBFloat16);
   using DataType = __nv_bfloat16;
   using IndexType = int32_t;
 
-  TORCH_CHECK(scores.dim() == 2);
-  const int num_tokens = scores.size(0);
-  const int num_experts = scores.size(1);
+  TORCH_CHECK(routing_scores.dim() == 2);
+  const int num_tokens = routing_scores.size(0);
+  const int num_experts = routing_scores.size(1);
   TORCH_CHECK(num_experts == 16 || num_experts == 128);
 
   auto allocate_index_tensor = [&](int size) {
     return at::empty(
-        {size}, at::TensorOptions().dtype(at::kInt).device(scores.device()));
+        {size},
+        at::TensorOptions().dtype(at::kInt).device(routing_scores.device()));
   };
-  at::Tensor counts = allocate_index_tensor(num_experts + 1);
-  at::Tensor expert_indices = allocate_index_tensor(num_tokens);
-  at::Tensor token_indices = allocate_index_tensor(num_tokens);
+  at::Tensor token_count_per_expert = allocate_index_tensor(num_experts + 1);
   at::Tensor shuffled_expert_indices = allocate_index_tensor(num_tokens);
   at::Tensor shuffled_token_indices = allocate_index_tensor(num_tokens);
+  at::Tensor buffered_expert_indices = allocate_index_tensor(num_tokens);
+  at::Tensor buffered_token_indices = allocate_index_tensor(num_tokens);
 
 #ifdef USE_ROCM
-  counts.zero_();
   // TODO(shikaili): hipMetsetAsync is more expensive than ATen set zero.
-  /*
-  hipMemsetAsync(
-      counts.data_ptr(),
-      0,
-      counts.numel() * counts.dtype().itemsize(),
-      at::cuda::getCurrentCUDAStream());
-  */
+  token_count_per_expert.zero_();
 #else
   cudaMemsetAsync(
-      counts.data_ptr(),
+      token_count_per_expert.data_ptr(),
      0,
-      counts.numel() * counts.dtype().itemsize(),
+      token_count_per_expert.numel() *
+          token_count_per_expert.dtype().itemsize(),
       at::cuda::getCurrentCUDAStream());
 #endif
 
@@ -323,26 +335,30 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
 
   const int num_tiles = ceil_of_ratio(num_tokens, kNumTokensPerTile);
   const int num_ctas = std::min(num_tiles, num_sms);
-
-  int num_tokens_per_cta = ceil_of_ratio(num_tokens, num_ctas);
   const int num_tiles_per_cta =
-      ceil_of_ratio(num_tokens_per_cta, kNumTokensPerTile);
-  num_tokens_per_cta = num_tiles_per_cta * kNumTokensPerTile;
+      ceil_of_ratio(ceil_of_ratio(num_tokens, num_ctas), kNumTokensPerTile);
+  const int num_tokens_per_cta = num_tiles_per_cta * kNumTokensPerTile;
 
   Params<DataType, IndexType> params = {
-      reinterpret_cast<DataType*>(scores.data_ptr()),
-      static_cast<int>(scores.stride(0)),
-      static_cast<int>(scores.stride(1)),
-      num_tokens,
-      num_tokens_per_cta,
-      reinterpret_cast<IndexType*>(counts.data_ptr()),
-      reinterpret_cast<IndexType*>(expert_indices.data_ptr()),
-      reinterpret_cast<IndexType*>(token_indices.data_ptr()),
-      reinterpret_cast<IndexType*>(shuffled_expert_indices.data_ptr()),
-      reinterpret_cast<IndexType*>(shuffled_token_indices.data_ptr()),
-      reinterpret_cast<IndexType*>(
-          num_valid_tokens.has_value() ? num_valid_tokens->data_ptr()
-                                       : nullptr)};
+      // Inputs
+      .routing_scores = reinterpret_cast<DataType*>(routing_scores.data_ptr()),
+      .stride_t = static_cast<int>(routing_scores.stride(0)),
+      .stride_e = static_cast<int>(routing_scores.stride(1)),
+      .valid_token_count = get_ptr<IndexType>(valid_token_count),
+      .num_tokens = num_tokens,
+      .num_tokens_per_cta = num_tokens_per_cta,
+      // Buffer
+      .buffered_expert_indices =
+          reinterpret_cast<IndexType*>(buffered_expert_indices.data_ptr()),
+      .buffered_token_indices =
+          reinterpret_cast<IndexType*>(buffered_token_indices.data_ptr()),
+      // Outputs
+      .token_count_per_expert =
+          reinterpret_cast<IndexType*>(token_count_per_expert.data_ptr()),
+      .shuffled_expert_indices =
+          reinterpret_cast<IndexType*>(shuffled_expert_indices.data_ptr()),
+      .shuffled_token_indices =
+          reinterpret_cast<IndexType*>(shuffled_token_indices.data_ptr())};
 
   dim3 grids(num_ctas);
   dim3 blocks(kNumThreads);
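
Editorial note: the rewritten launch math folds the old two-step num_tokens_per_cta computation into nested ceil_of_ratio calls. With hypothetical sizes num_tokens = 1000, kNumTokensPerTile = 16, and num_sms = 132: num_tiles = ceil(1000 / 16) = 63, num_ctas = min(63, 132) = 63, num_tiles_per_cta = ceil(ceil(1000 / 63) / 16) = ceil(16 / 16) = 1, and num_tokens_per_cta = 1 * 16 = 16, matching what the previous formulation produced.
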
@@ -360,7 +376,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> index_shuffling_torch(
 #endif
 
   return std::make_tuple(
-      counts, shuffled_expert_indices, shuffled_token_indices);
+      token_count_per_expert, shuffled_expert_indices, shuffled_token_indices);
 }
 
 } // namespace fbgemm_gpu
