moeUtilOp.cpp
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/moeUtilOp.h"
#include "moe_gemm_kernels.h"
#include "tensorrt_llm/common/workspace.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h"
#include "tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/thop/thUtils.h"
#include <ATen/native/cuda/Resize.h>
namespace th = torch;
namespace tl = tensorrt_llm;
namespace tk = tensorrt_llm::kernels;
namespace common = tensorrt_llm::common;
namespace kernels = tensorrt_llm::kernels;
namespace cutlass_kernels = tensorrt_llm::kernels::cutlass_kernels;
namespace torch_ext
{
// input: input_activations, [num_tokens, hidden_size]
// input: token_topk_unpermuted_scales, [num_tokens, k]
// output: permuted_data_, [num_tokens * k, hidden_size]
// output: permuted_token_final_scales_, [num_tokens * k]
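// Permute pipeline overview (summary of the steps implemented below):
//   1. build expert maps: map each (token, k) pair to its expert id and bucket
//      tokens per expert (fusedBuildExpertMapsSortFirstToken, or the unfused
//      buildExpertMaps + generateTokenPermutation fallback);
//   2. sort: order the expanded token ids by expert id so each expert's tokens
//      are contiguous, producing expert_first_token_offset_ as prefix offsets;
//   3. expand: gather (and duplicate k times) the activations into
//      permuted_data_ so a grouped GEMM can consume one contiguous slice per expert.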
template <typename T>
void runPermute(void const* input_activations_void, void const* input_sf_void, int const* token_selected_experts,
float const* token_final_scales, void const* fc1_expert_weights_void, void const* fc1_expert_biases_void,
tensorrt_llm::ActivationType fc1_activation_type, void const* fc2_expert_weights_void,
void const* fc2_expert_biases_void, cutlass_kernels::QuantParams quant_params, int64_t const num_rows,
int64_t const hidden_size, int const full_num_experts, int const experts_per_token,
int* unpermuted_token_selected_experts_, int* unpermuted_source_token_ids_, int* permuted_source_token_ids_,
int* permuted_token_selected_experts_, T* permuted_data_, char* sorter_ws_, int64_t* expert_first_token_offset_,
float* permuted_token_final_scales_, int* expanded_source_row_to_expanded_dest_row,
cutlass_kernels::MOEParallelismConfig parallelism_config, cutlass_kernels::CubKeyValueSorter sorter_, bool use_lora,
kernels::LoraParams& lora_params, bool use_fp8_block_scaling, bool min_latency_mode,
cutlass_kernels::MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
{
TLLM_CHECK_WITH_INFO(experts_per_token * full_num_experts <= std::numeric_limits<int>::max(),
"experts_per_token * num_experts is too large");
auto const* input_activations = static_cast<T const*>(input_activations_void);
auto const* input_sf = input_sf_void
? reinterpret_cast<tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::ElementSF const*>(input_sf_void)
: nullptr;
int const num_experts_per_node = full_num_experts / parallelism_config.ep_size;
int start_expert = num_experts_per_node * parallelism_config.ep_rank;
int end_expert = start_expert + num_experts_per_node;
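// Example: with full_num_experts = 8 and ep_size = 2, each rank owns
// num_experts_per_node = 4 experts; ep_rank 1 handles experts [4, 8).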
bool const needs_num_valid = parallelism_config.ep_size > 1;
// Note: expert_first_token_offset_[num_experts_per_node] stores the total number of expanded tokens
int64_t const* num_valid_tokens_ptr = needs_num_valid ? expert_first_token_offset_ + num_experts_per_node : nullptr;
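// Worked example (illustrative): with num_experts_per_node = 3 and expanded
// token counts {4, 0, 2} per expert, expert_first_token_offset_ = {0, 4, 4, 6};
// the final entry (6) is the total expanded token count that
// num_valid_tokens_ptr points at.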
bool use_w4afp8 = false;
bool fused_prologue_result = false;
if (!use_w4afp8)
{
// WAR: fusedBuildExpertMapsSortFirstToken kernel will lead to illegal memory access for W4AFP8
// input: token_selected_experts, [num_tokens, k]
// output: unpermuted_token_selected_experts_, [num_tokens, k]
// output: permuted_source_token_ids_, [num_tokens, k]
// output: expert_first_token_offset_, [num_experts_per_node + 1]
fused_prologue_result = kernels::fusedBuildExpertMapsSortFirstToken(token_selected_experts,
unpermuted_token_selected_experts_, permuted_source_token_ids_, expert_first_token_offset_, num_rows,
num_experts_per_node, experts_per_token, start_expert, end_expert, stream);
}
if (!fused_prologue_result)
{
TLLM_LOG_TRACE("Falling back to unfused prologue");
kernels::buildExpertMaps(token_selected_experts, unpermuted_token_selected_experts_,
unpermuted_source_token_ids_, num_rows, num_experts_per_node, experts_per_token, start_expert, end_expert,
stream);
sync_check_cuda_error(stream);
kernels::generateTokenPermutation(unpermuted_token_selected_experts_, unpermuted_source_token_ids_,
permuted_token_selected_experts_, permuted_source_token_ids_, expert_first_token_offset_, num_rows,
num_experts_per_node, experts_per_token, sorter_, static_cast<void*>(sorter_ws_), stream);
}
sync_check_cuda_error(stream);
// using ExpandedActivationsType = std::conditional_t<use_w4afp8, BackBoneType, T>;
using ExpandedActivationsType = T;
// input_activations: [num_tokens, hidden_size]
// output: permuted_data_, [num_tokens * k, hidden_size]
// input: token_topk_unpermuted_scales, [num_tokens, k]
// output: permuted_token_final_scales_, [num_tokens * k]
// input: permuted_source_token_ids_, [num_tokens, k]
// output: expanded_source_row_to_expanded_dest_row, [num_tokens, k]
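// expanded_source_row_to_expanded_dest_row records, for each expanded source
// row, its destination row inside permuted_data_, so the finalize step can map
// the grouped GEMM outputs back to their source tokens.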
float const* token_topk_unpermuted_scales = token_final_scales;
kernels::expandInputRowsKernelLauncher(input_activations,
reinterpret_cast<ExpandedActivationsType*>(permuted_data_), token_topk_unpermuted_scales,
permuted_token_final_scales_, permuted_source_token_ids_, expanded_source_row_to_expanded_dest_row, num_rows,
num_valid_tokens_ptr, hidden_size, experts_per_token, num_experts_per_node,
quant_params.fp4.fc1.act_global_scale, expert_first_token_offset_,
/* fc1_fp4_act_scale_ */ nullptr, input_sf, stream);
sync_check_cuda_error(stream);
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor>
moe_permute_op(torch::Tensor const& input, torch::Tensor const& token_selected_experts,
torch::optional<torch::Tensor> token_final_scales, torch::Tensor const& fc1_expert_weights,
torch::Tensor const& fc2_expert_weights, torch::optional<c10::ArrayRef<torch::Tensor>> quant_scales,
torch::optional<torch::Tensor> input_sf, int64_t const num_experts_on_rank, int64_t const tp_size,
int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
int64_t const cluster_rank, bool min_latency_mode, bool use_fp8_block_scaling)
{
cutlass_kernels::CubKeyValueSorter sorter_;
TORCH_CHECK(cluster_size == 1 && cluster_rank == 0, "smart_router (cluster_size > 1) is only supported in min_latency mode");
TORCH_CHECK(min_latency_mode == false, "min_latency_mode is not supported yet");
CHECK_INPUT(token_selected_experts, at::ScalarType::Int)
if (token_final_scales)
{
CHECK_INPUT(token_final_scales.value(), at::ScalarType::Float)
}
TORCH_CHECK(input.dim() == 2, "input must be 2D.");
TORCH_CHECK(token_selected_experts.dim() == 2, "token_selected_experts must be 2D.");
TORCH_CHECK(input.sizes()[0] == token_selected_experts.sizes()[0],
"input and token_selected_experts must have the same num tokens.");
if (token_final_scales)
{
TORCH_CHECK(token_final_scales.value().dim() == 2, "token_final_scales must be 2D.");
TORCH_CHECK(input.sizes()[0] == token_final_scales.value().sizes()[0],
"input and token_final_scales must have the same num tokens.");
TORCH_CHECK(token_selected_experts.sizes()[1] == token_final_scales.value().sizes()[1],
"token_selected_experts and token_final_scales must have the same number of experts per token.");
}
int experts_per_token = token_selected_experts.sizes()[1];
int64_t num_rows = input.sizes()[0];
int64_t hidden_size = input.sizes()[1];
auto const num_experts_total = static_cast<int>(num_experts_on_rank * ep_size);
auto parallelism_config = cutlass_kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank);
auto activation_type = tensorrt_llm::ActivationType::Swiglu;
int const num_experts_per_node = num_experts_on_rank;
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
size_t num_moe_inputs = experts_per_token * num_rows;
auto unpermuted_token_selected_experts_tensor
= torch::empty({num_moe_inputs}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
auto unpermuted_source_token_ids_tensor
= torch::empty({num_moe_inputs}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
auto permuted_source_token_ids_tensor
= torch::empty({num_moe_inputs}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
auto permuted_token_selected_experts_tensor
= torch::empty({num_moe_inputs}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
auto permuted_data_tensor = torch::empty({num_moe_inputs, hidden_size}, input.options().requires_grad(false));
auto permuted_token_final_scales_tensor
= torch::empty({num_moe_inputs}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
auto expert_first_token_offset_tensor = torch::empty(
{num_experts_per_node + 1}, torch::dtype(torch::kInt64).device(torch::kCUDA).requires_grad(false));
size_t const sorter_size = min_latency_mode
? 0
: cutlass_kernels::CubKeyValueSorter::getWorkspaceSize(num_rows * experts_per_token, num_experts_per_node);
auto sorter_ws_tensor
= torch::empty({sorter_size}, torch::dtype(torch::kChar).device(torch::kCUDA).requires_grad(false));
auto src_to_dest_map_tensor = torch::empty(
{experts_per_token * num_rows}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
cutlass_kernels::QuantParams quant_params{};
cutlass_kernels::MoeMinLatencyParams min_latency_params{};
kernels::LoraParams lora_params{};
auto data_type = input.scalar_type();
switch (data_type)
{
case torch::kFloat32:
runPermute<float>(input.const_data_ptr(), input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr,
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
/*fc1_expert_weights.const_data_ptr()*/ nullptr, nullptr, activation_type,
/*fc2_expert_weights.const_data_ptr()*/ nullptr, nullptr, quant_params, num_rows, hidden_size,
num_experts_total, static_cast<int>(experts_per_token),
static_cast<int*>(unpermuted_token_selected_experts_tensor.data_ptr()),
static_cast<int*>(unpermuted_source_token_ids_tensor.data_ptr()),
static_cast<int*>(permuted_source_token_ids_tensor.data_ptr()),
static_cast<int*>(permuted_token_selected_experts_tensor.data_ptr()),
static_cast<float*>(permuted_data_tensor.data_ptr()), static_cast<char*>(sorter_ws_tensor.data_ptr()),
static_cast<int64_t*>(expert_first_token_offset_tensor.data_ptr()),
static_cast<float*>(permuted_token_final_scales_tensor.data_ptr()),
static_cast<int*>(src_to_dest_map_tensor.data_ptr()), parallelism_config, sorter_, false, lora_params,
use_fp8_block_scaling, min_latency_mode, min_latency_params, stream);
break;
case torch::kBFloat16:
runPermute<__nv_bfloat16>(input.const_data_ptr(),
input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr,
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
/*fc1_expert_weights.const_data_ptr()*/ nullptr, nullptr, activation_type,
/*fc2_expert_weights.const_data_ptr()*/ nullptr, nullptr, quant_params, num_rows, hidden_size,
num_experts_total, static_cast<int>(experts_per_token),
static_cast<int*>(unpermuted_token_selected_experts_tensor.data_ptr()),
static_cast<int*>(unpermuted_source_token_ids_tensor.data_ptr()),
static_cast<int*>(permuted_source_token_ids_tensor.data_ptr()),
static_cast<int*>(permuted_token_selected_experts_tensor.data_ptr()),
static_cast<__nv_bfloat16*>(permuted_data_tensor.data_ptr()),
static_cast<char*>(sorter_ws_tensor.data_ptr()),
static_cast<int64_t*>(expert_first_token_offset_tensor.data_ptr()),
static_cast<float*>(permuted_token_final_scales_tensor.data_ptr()),
static_cast<int*>(src_to_dest_map_tensor.data_ptr()), parallelism_config, sorter_, false, lora_params,
use_fp8_block_scaling, min_latency_mode, min_latency_params, stream);
break;
case torch::kHalf:
runPermute<half>(input.const_data_ptr(), input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr,
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
/*fc1_expert_weights.const_data_ptr()*/ nullptr, nullptr, activation_type,
/*fc2_expert_weights.const_data_ptr()*/ nullptr, nullptr, quant_params, num_rows, hidden_size,
num_experts_total, static_cast<int>(experts_per_token),
static_cast<int*>(unpermuted_token_selected_experts_tensor.data_ptr()),
static_cast<int*>(unpermuted_source_token_ids_tensor.data_ptr()),
static_cast<int*>(permuted_source_token_ids_tensor.data_ptr()),
static_cast<int*>(permuted_token_selected_experts_tensor.data_ptr()),
static_cast<half*>(permuted_data_tensor.data_ptr()), static_cast<char*>(sorter_ws_tensor.data_ptr()),
static_cast<int64_t*>(expert_first_token_offset_tensor.data_ptr()),
static_cast<float*>(permuted_token_final_scales_tensor.data_ptr()),
static_cast<int*>(src_to_dest_map_tensor.data_ptr()), parallelism_config, sorter_, false, lora_params,
use_fp8_block_scaling, min_latency_mode, min_latency_params, stream);
break;
default:
throw std::invalid_argument(
"Invalid dtype, only supports intput tensor with float32, float16 and bfloat16 dtype");
break;
}
return std::make_tuple(unpermuted_token_selected_experts_tensor, unpermuted_source_token_ids_tensor,
permuted_source_token_ids_tensor, permuted_token_selected_experts_tensor, permuted_data_tensor,
expert_first_token_offset_tensor, permuted_token_final_scales_tensor, src_to_dest_map_tensor);
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> run_moe_expand_op(torch::Tensor const& input,
torch::optional<torch::Tensor> token_final_scales, torch::Tensor const& permuted_source_token_ids,
int64_t const num_rows, torch::Tensor& expert_first_token_offset_tensor, int64_t const hidden_size,
int64_t const experts_per_token, int64_t const num_experts_per_node, int64_t const tp_size, int64_t const tp_rank,
int64_t const ep_size, int64_t const ep_rank, bool use_fp8_block_scaling)
{
auto parallelism_config = cutlass_kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank);
bool const needs_num_valid = parallelism_config.ep_size > 1;
int64_t const* num_valid_tokens_ptr = needs_num_valid
? static_cast<int64_t*>(expert_first_token_offset_tensor.data_ptr()) + num_experts_per_node
: nullptr;
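// Note (an assumption inferred from the rounding below): with
// use_fp8_block_scaling the expanded row count is padded up to a multiple of 4,
// presumably to satisfy an alignment requirement of the FP8 block-scale GEMM.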
size_t num_moe_inputs
= use_fp8_block_scaling ? (experts_per_token * num_rows + 3) / 4 * 4 : experts_per_token * num_rows;
auto permuted_data_tensor = torch::empty({num_moe_inputs, hidden_size}, input.options().requires_grad(false));
auto permuted_token_final_scales_tensor
= torch::empty({num_moe_inputs}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
auto expanded_source_row_to_expanded_dest_row
= torch::empty({num_moe_inputs}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
cutlass_kernels::QuantParams quant_params{};
float const* token_topk_unpermuted_scales = token_final_scales.has_value()
? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr;
auto data_type = input.scalar_type();
switch (data_type)
{
case torch::kFloat32:
kernels::expandInputRowsKernelLauncher<float, float>(static_cast<float const*>(input.const_data_ptr()),
reinterpret_cast<float*>(permuted_data_tensor.data_ptr()), token_topk_unpermuted_scales,
static_cast<float*>(permuted_token_final_scales_tensor.data_ptr()),
static_cast<int const*>(permuted_source_token_ids.const_data_ptr()),
static_cast<int*>(expanded_source_row_to_expanded_dest_row.data_ptr()), num_rows, num_valid_tokens_ptr,
hidden_size, experts_per_token, num_experts_per_node, quant_params.fp4.fc1.act_global_scale,
static_cast<int64_t*>(expert_first_token_offset_tensor.data_ptr()),
/* fc1_fp4_act_scale_ */ nullptr, /*input_sf*/ nullptr, stream);
break;
case torch::kBFloat16:
kernels::expandInputRowsKernelLauncher<__nv_bfloat16, __nv_bfloat16>(
static_cast<__nv_bfloat16 const*>(input.const_data_ptr()),
reinterpret_cast<__nv_bfloat16*>(permuted_data_tensor.data_ptr()), token_topk_unpermuted_scales,
static_cast<float*>(permuted_token_final_scales_tensor.data_ptr()),
static_cast<int const*>(permuted_source_token_ids.const_data_ptr()),
static_cast<int*>(expanded_source_row_to_expanded_dest_row.data_ptr()), num_rows, num_valid_tokens_ptr,
hidden_size, experts_per_token, num_experts_per_node, quant_params.fp4.fc1.act_global_scale,
static_cast<int64_t*>(expert_first_token_offset_tensor.data_ptr()),
/* fc1_fp4_act_scale_ */ nullptr, /*input_sf*/ nullptr, stream);
break;
case torch::kHalf:
kernels::expandInputRowsKernelLauncher<half, half>(static_cast<half const*>(input.const_data_ptr()),
reinterpret_cast<half*>(permuted_data_tensor.data_ptr()), token_topk_unpermuted_scales,
static_cast<float*>(permuted_token_final_scales_tensor.data_ptr()),
static_cast<int const*>(permuted_source_token_ids.const_data_ptr()),
static_cast<int*>(expanded_source_row_to_expanded_dest_row.data_ptr()), num_rows, num_valid_tokens_ptr,
hidden_size, experts_per_token, num_experts_per_node, quant_params.fp4.fc1.act_global_scale,
static_cast<int64_t*>(expert_first_token_offset_tensor.data_ptr()),
/* fc1_fp4_act_scale_ */ nullptr, /*input_sf*/ nullptr, stream);
break;
default:
throw std::invalid_argument(
"Invalid dtype, only supports intput tensor with float32, float16 and bfloat16 dtype");
break;
}
return std::make_tuple(
permuted_data_tensor, permuted_token_final_scales_tensor, expanded_source_row_to_expanded_dest_row);
}
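// Finalize step: un-permutes the grouped GEMM output and reduces each token's k
// expert contributions, roughly
//   final_output[t] = sum_j unpermuted_final_scales[t][j] * gemm2_output[dest_row(t, j)]
// plus fc2_expert_biases when provided (the op below always passes nullptr).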
template <class UnfusedGemmOutputType, class ScaleBiasType, class OutputType>
void runMoEFinalizeScaleOp(UnfusedGemmOutputType const* const gemm2_output,
ScaleBiasType const* const fc2_expert_biases, float const* const unpermuted_final_scales,
int const* const expanded_source_row_to_expanded_dest_row, int const* const expert_for_source_row,
int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, /*int64_t const expanded_num_rows,*/
int64_t const hidden_size, /*int64_t const inter_size, int const num_experts_per_node,*/
int64_t const experts_per_token, cutlass_kernels::MOEParallelismConfig parallelism_config, cudaStream_t stream,
OutputType* const final_output)
{
kernels::finalizeMoeRoutingKernelLauncher<OutputType, UnfusedGemmOutputType>(
static_cast<UnfusedGemmOutputType const*>(gemm2_output), final_output, fc2_expert_biases,
unpermuted_final_scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, num_rows, hidden_size,
experts_per_token, num_valid_tokens_ptr, parallelism_config, stream);
}
torch::Tensor run_moe_finalize_scale_op(torch::Tensor const& gemm2_output, torch::Tensor const& fc2_expert_biases,
torch::Tensor const& unpermuted_final_scales, torch::Tensor const& expanded_source_row_to_expanded_dest_row,
torch::Tensor const& expert_for_source_row, torch::Tensor const& expert_first_token_offset_tensor,
int64_t const num_rows, int64_t const hidden_size, int64_t const experts_per_token,
int64_t const num_experts_per_node, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size,
int64_t const ep_rank)
{
TORCH_CHECK(gemm2_output.dim() == 2, "gemm2_output must be 2D.");
TORCH_CHECK(unpermuted_final_scales.dim() == 2, "unpermuted_final_scales must be 2D.");
TORCH_CHECK(
expanded_source_row_to_expanded_dest_row.dim() == 1, "expanded_source_row_to_expanded_dest_row must be 1D.");
TORCH_CHECK(expert_for_source_row.dim() == 1, "expert_for_source_row must be 1D.");
TORCH_CHECK(expert_first_token_offset_tensor.dim() == 1, "expert_first_token_offset_tensor must be 1D.");
TORCH_CHECK(gemm2_output.sizes()[0] == expert_for_source_row.sizes()[0],
"gemm2_output and expert_for_source_row must have the same expanded num tokens.");
TORCH_CHECK(unpermuted_final_scales.sizes()[0] == num_rows,
"unpermuted_final_scales.sizes()[0] should equal num_rows.");
TORCH_CHECK(unpermuted_final_scales.sizes()[1] == experts_per_token,
"unpermuted_final_scales.sizes()[1] should equal experts_per_token.");
TORCH_CHECK(expert_for_source_row.sizes()[0] == gemm2_output.sizes()[0],
"expert_for_source_row and gemm2_output must have the same expanded num tokens.");
TORCH_CHECK(expert_first_token_offset_tensor.sizes()[0] == num_experts_per_node + 1,
"expert_first_token_offset_tensor.sizes()[0] should equal num_experts_per_node + 1.");
auto parallelism_config = cutlass_kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank);
bool const needs_num_valid = parallelism_config.ep_size > 1;
int64_t const* num_valid_tokens_ptr = needs_num_valid
? static_cast<int64_t const*>(expert_first_token_offset_tensor.const_data_ptr()) + num_experts_per_node
: nullptr;
auto final_output = torch::empty({num_rows, hidden_size}, gemm2_output.options());
auto stream = at::cuda::getCurrentCUDAStream(gemm2_output.get_device());
auto data_type = gemm2_output.scalar_type();
switch (data_type)
{
case torch::kFloat32:
runMoEFinalizeScaleOp<float, float, float>(static_cast<float const*>(gemm2_output.const_data_ptr()),
// static_cast<float const*>(fc2_expert_biases.const_data_ptr()),
nullptr, static_cast<float const*>(unpermuted_final_scales.const_data_ptr()),
static_cast<int const*>(expanded_source_row_to_expanded_dest_row.const_data_ptr()),
static_cast<int const*>(expert_for_source_row.const_data_ptr()), num_valid_tokens_ptr, num_rows,
hidden_size, experts_per_token, parallelism_config, stream, static_cast<float*>(final_output.data_ptr()));
break;
case torch::kBFloat16:
runMoEFinalizeScaleOp<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>(
static_cast<__nv_bfloat16 const*>(gemm2_output.const_data_ptr()),
// static_cast<__nv_bfloat16 const*>(fc2_expert_biases.const_data_ptr()),
nullptr, static_cast<float const*>(unpermuted_final_scales.const_data_ptr()),
static_cast<int const*>(expanded_source_row_to_expanded_dest_row.const_data_ptr()),
static_cast<int const*>(expert_for_source_row.const_data_ptr()), num_valid_tokens_ptr, num_rows,
hidden_size, experts_per_token, parallelism_config, stream,
static_cast<__nv_bfloat16*>(final_output.data_ptr()));
break;
case torch::kHalf:
runMoEFinalizeScaleOp<half, half, half>(static_cast<half const*>(gemm2_output.const_data_ptr()),
// static_cast<half const*>(fc2_expert_biases.const_data_ptr()),
nullptr, static_cast<float const*>(unpermuted_final_scales.const_data_ptr()),
static_cast<int const*>(expanded_source_row_to_expanded_dest_row.const_data_ptr()),
static_cast<int const*>(expert_for_source_row.const_data_ptr()), num_valid_tokens_ptr, num_rows,
hidden_size, experts_per_token, parallelism_config, stream, static_cast<half*>(final_output.data_ptr()));
break;
default:
throw std::invalid_argument(
"Invalid dtype, only supports intput tensor with float32, float16 and bfloat16 dtype");
break;
}
return final_output;
}
} // namespace torch_ext
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"moe_permute_op(Tensor input, Tensor token_selected_experts, Tensor? token_final_scales, Tensor "
"fc1_expert_weights, Tensor fc2_expert_weights, Tensor[]? quant_scales, Tensor? input_sf, int "
"num_experts_on_rank, int tp_size, int tp_rank, int ep_size, int ep_rank, int cluster_size, int cluster_rank, "
"bool min_latency_mode, bool use_fp8_block_scaling)"
"-> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def(
"moe_finalize_scale_op(Tensor gemm2_output, Tensor fc2_expert_biases, Tensor unpermuted_final_scales, Tensor "
"expanded_source_row_to_expanded_dest_row, Tensor expert_for_source_row, Tensor "
"expert_first_token_offset_tensor, int num_rows, int hidden_size, int experts_per_token, int "
"num_experts_per_node, int tp_size, int tp_rank, int ep_size, int ep_rank)"
"-> (Tensor)");
m.def(
"moe_expand_op(Tensor input, Tensor? token_final_scales, Tensor permuted_source_token_ids, int num_rows, "
"Tensor expert_first_token_offset_tensor, int hidden_size, int experts_per_token, int num_experts_per_node, "
"int tp_size, int tp_rank, int ep_size, int ep_rank, bool use_fp8_block_scaling)"
"-> (Tensor, Tensor, Tensor)");
}
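// Minimal usage sketch (illustrative; shapes follow the checks in
// moe_permute_op above). Once this library is loaded, the op is callable from
// Python via torch.ops, e.g.:
//
//   outs = torch.ops.trtllm.moe_permute_op(
//       input,                   # [num_tokens, hidden_size], float32/float16/bfloat16, CUDA
//       token_selected_experts,  # [num_tokens, k], int32
//       token_final_scales,      # [num_tokens, k], float32, or None
//       fc1_expert_weights, fc2_expert_weights,
//       None, None,              # quant_scales, input_sf
//       num_experts_on_rank, tp_size, tp_rank, ep_size, ep_rank,
//       1, 0,                    # cluster_size, cluster_rank (smart router disabled)
//       False, False)            # min_latency_mode, use_fp8_block_scaling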
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("moe_permute_op", &torch_ext::moe_permute_op);
m.impl("moe_finalize_scale_op", &torch_ext::run_moe_finalize_scale_op);
m.impl("moe_expand_op", &torch_ext::run_moe_expand_op);
}