[Kernel] Integrate CUTLASS MoE kernel with PPLX #18762
@@ -36,7 +36,8 @@ void cutlass_moe_mm_sm90(
     torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
-    torch::Tensor const& b_strides, torch::Tensor const& c_strides);
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch);

 #endif
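For context, a minimal sketch of how a caller might derive the two new flags. The shape conventions below are assumptions about the quantization scheme, not something this header specifies:

```cpp
#include <torch/all.h>

// Illustrative helpers, not part of this PR. Assumed conventions: per-token
// activation quantization means more than one element in a_scales (one scale
// per input row), and per-output-channel weight quantization means more than
// one scale per expert in b_scales.
inline bool infer_per_act_token(torch::Tensor const& a_scales) {
  return a_scales.numel() != 1;
}

inline bool infer_per_out_ch(torch::Tensor const& b_scales,
                             int64_t num_experts) {
  return b_scales.numel() != num_experts;
}
```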
@@ -55,6 +56,12 @@ void get_cutlass_moe_mm_data_caller(
     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
     torch::Tensor& input_permutation, torch::Tensor& output_permutation,
     const int64_t num_experts, const int64_t n, const int64_t k);
+
+void get_cutlass_pplx_moe_mm_data_caller(
+    torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
+    const torch::Tensor& non_zero_expert_idxs, const int64_t num_local_experts,
+    const int64_t padded_m, const int64_t n, const int64_t k);
 #endif

 void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
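To make the new entry point's job concrete, here is a CPU-side sketch of the metadata it is expected to produce. The (m, 2n, k) / (m, k, n) problem-size convention, the fixed padded_m-row slab per local expert, and the omission of non_zero_expert_idxs (presumably used to skip empty experts on device) are all inferences from the surrounding code, not guarantees from this diff:

```cpp
#include <torch/all.h>

// CPU reference of the metadata the device-side caller is expected to fill.
// Assumptions: problem_sizes1 describes the first grouped GEMM as (m, 2n, k)
// and problem_sizes2 the second as (m, k, n), matching the existing MoE data
// kernel; in the PPLX layout each local expert owns a fixed padded_m-row slab
// of the activation buffer, so offsets are a stride, not a prefix sum.
void pplx_moe_mm_data_reference(torch::Tensor& expert_offsets,
                                torch::Tensor& problem_sizes1,
                                torch::Tensor& problem_sizes2,
                                torch::Tensor const& expert_num_tokens,
                                int64_t num_local_experts, int64_t padded_m,
                                int64_t n, int64_t k) {
  auto counts = expert_num_tokens.accessor<int32_t, 1>();
  auto offsets = expert_offsets.accessor<int32_t, 1>();
  auto ps1 = problem_sizes1.accessor<int32_t, 2>();
  auto ps2 = problem_sizes2.accessor<int32_t, 2>();
  for (int64_t e = 0; e < num_local_experts; ++e) {
    offsets[e] = static_cast<int32_t>(e * padded_m);
    int32_t const m = counts[e];  // tokens actually routed to this expert
    ps1[e][0] = m;
    ps1[e][1] = static_cast<int32_t>(2 * n);
    ps1[e][2] = static_cast<int32_t>(k);
    ps2[e][0] = m;
    ps2[e][1] = static_cast<int32_t>(k);
    ps2[e][2] = static_cast<int32_t>(n);
  }
}
```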
@@ -206,12 +213,13 @@ void cutlass_moe_mm(
     torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
-    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch) {
   int32_t version_num = get_sm_version_num();
 #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
   cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                       expert_offsets, problem_sizes, a_strides, b_strides,
-                      c_strides);
+                      c_strides, per_act_token, per_out_ch);
   return;
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(
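Because the public signature changed, the corresponding op schema has to change with it. A hedged sketch of what that registration could look like; the actual binding file, namespace, and schema text in vLLM may differ:

```cpp
#include <torch/library.h>

// Hypothetical schema update mirroring the new C++ signature above.
// "_sketch_ops" is a placeholder namespace used only for illustration.
TORCH_LIBRARY(_sketch_ops, ops) {
  ops.def(
      "cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors,"
      " Tensor b_tensors, Tensor a_scales, Tensor b_scales,"
      " Tensor expert_offsets, Tensor problem_sizes, Tensor a_strides,"
      " Tensor b_strides, Tensor c_strides,"
      " bool per_act_token, bool per_out_ch) -> ()");
  ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm);
}
```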
@@ -242,6 +250,28 @@ void get_cutlass_moe_mm_data(
       version_num, ". Required capability: 90");
 }
+
+void get_cutlass_pplx_moe_mm_data(
+    torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
+    const torch::Tensor& non_zero_expert_idxs, const int64_t num_local_experts,
+    const int64_t padded_m, const int64_t n, const int64_t k) {
+  // This function currently gets compiled only if we have a valid cutlass moe
+  // mm to run it for.
+  int32_t version_num = get_sm_version_num();
+#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
+    (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
+  get_cutlass_pplx_moe_mm_data_caller(
+      expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens,
+      non_zero_expert_idxs, num_local_experts, padded_m, n, k);
+  return;
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false,
+      "No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
+      "for CUDA device capability: ",
+      version_num, ". Required capability: 90");
+}

 void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,

Review comment (on the `#if` condition): Should the second clause be `(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100)`? As written it tests whether `ENABLE_SCALED_MM_SM100` is defined but then checks the value of `ENABLE_SCALED_MM_SM90`. I'm also not sure why we're looking at the SM100 macros here at all.

Reply: Copy-paste issue; I didn't notice that the old MoE data kernel I copied this from has SM100 support for fp4 now.
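Putting the two new entry points together, a hypothetical call sequence might look like the following. Everything about the caller here (argument sources, the [num_local_experts, 3] problem-size layout, allocation details) is an assumption for illustration:

```cpp
#include <torch/all.h>

// Hypothetical glue showing the intended call order of the two new ops.
void run_pplx_cutlass_moe_layer(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_num_tokens,
    torch::Tensor const& non_zero_expert_idxs, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    int64_t num_local_experts, int64_t padded_m, int64_t n, int64_t k,
    bool per_act_token, bool per_out_ch) {
  auto opts = torch::dtype(torch::kInt32).device(a_tensors.device());
  auto expert_offsets = torch::empty({num_local_experts}, opts);
  auto problem_sizes1 = torch::empty({num_local_experts, 3}, opts);
  auto problem_sizes2 = torch::empty({num_local_experts, 3}, opts);

  // 1) Derive per-expert GEMM shapes from the PPLX dispatch metadata.
  get_cutlass_pplx_moe_mm_data(expert_offsets, problem_sizes1, problem_sizes2,
                               expert_num_tokens, non_zero_expert_idxs,
                               num_local_experts, padded_m, n, k);

  // 2) Run the grouped GEMM with the new quantization flags.
  cutlass_moe_mm(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                 expert_offsets, problem_sizes1, a_strides, b_strides,
                 c_strides, per_act_token, per_out_ch);
}
```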
Review comment: Could you double-check and try not to add any warnings to the build? (The implicit down-conversion here looks safe enough to me, but it's best to avoid implicit conversions.)

Reply: I don't see this producing any new warnings, but I'll make the change.
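A minimal sketch of the kind of change being discussed, assuming the potential warning comes from passing int64_t arguments where the kernel expects int32_t:

```cpp
#include <cstdint>
#include <limits>
#include <torch/all.h>

// Assumed shape of the fix: make the int64_t -> int32_t narrowing explicit
// (and range-checked) at the call boundary instead of relying on an implicit
// conversion that compilers may warn about.
inline int32_t checked_narrow(int64_t v) {
  TORCH_CHECK(v >= std::numeric_limits<int32_t>::min() &&
                  v <= std::numeric_limits<int32_t>::max(),
              "value does not fit in int32_t: ", v);
  return static_cast<int32_t>(v);
}
```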