[Kernel] Integrate CUTLASS MoE kernel with PPLX #18762

Open · wants to merge 15 commits into base: main

Changes from 13 commits
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -543,8 +543,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# CUTLASS MoE kernels

# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
# on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
# to compile MoE kernels that use its output.
# on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled
# if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
61 changes: 11 additions & 50 deletions benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
@@ -6,8 +6,8 @@

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import (
cutlass_moe_fp8,
fused_experts,
fused_topk,
)
@@ -69,18 +69,9 @@ def bench_run(
w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)

ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)

for expert in range(num_experts):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
w1_q_notransp = w1_q.clone()
w2_q_notransp = w2_q.clone()
w1_q = w1_q.transpose(1, 2)
w2_q = w2_q.transpose(1, 2)

score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

@@ -121,25 +112,17 @@ def run_cutlass_moe(
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
num_repeats: int,
):
for _ in range(num_repeats):
cutlass_moe_fp8(
a,
w1,
w2,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
w1_scale,
w2_scale,
a1_scale=a_scale,
)

@@ -152,10 +135,6 @@ def run_cutlass_from_graph(
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
@@ -164,14 +143,10 @@ def run_cutlass_from_graph(
a,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
w1_scale,
w2_scale,
a1_scale=a_scale,
)

@@ -217,10 +192,6 @@ def replay_graph(graph, num_repeats):
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
)
torch.cuda.synchronize()

@@ -229,8 +200,8 @@ def replay_graph(graph, num_repeats):
with torch.cuda.graph(triton_graph, stream=triton_stream):
run_triton_from_graph(
a,
w1_q_notransp,
w2_q_notransp,
w1_q,
w2_q,
topk_weights,
topk_ids,
w1_scale,
@@ -249,18 +220,12 @@ def replay_graph(graph, num_repeats):
"w2": w2,
"score": score,
"topk": topk,
"w1_q_notransp": w1_q_notransp,
"w2_q_notransp": w2_q_notransp,
# Cutlass params
"a_scale": a_scale,
"w1_q": w1_q,
"w2_q": w2_q,
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"ab_strides1": ab_strides1,
"c_strides1": c_strides1,
"ab_strides2": ab_strides2,
"c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@@ -278,8 +243,8 @@ def replay_graph(graph, num_repeats):
# Warmup
run_triton_moe(
a,
w1_q_notransp,
w2_q_notransp,
w1_q,
w2_q,
topk_weights,
topk_ids,
w1_scale,
@@ -290,7 +255,7 @@ def replay_graph(graph, num_repeats):

results.append(
benchmark.Timer(
stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
@@ -321,16 +286,12 @@ def replay_graph(graph, num_repeats):
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
num_warmup,
)

results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
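For readers following the signature change across the hunks above: the benchmark now calls cutlass_moe_fp8 without the ab/c stride tensors and without pre-transposing the quantized weights. A minimal, non-authoritative sketch of the updated invocation, using the tensor names from bench_run above (setup elided; stride handling is presumably done inside cutlass_moe_fp8 now):

```python
# Sketch of the updated benchmark call (names taken from bench_run above).
# The ab_strides*/c_strides* tensors and the *_notransp weight copies are no
# longer needed at the call site.
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8

cutlass_moe_fp8(
    a,             # fp16/bf16 activations
    w1_q,          # fp8 expert weights for the first grouped GEMM
    w2_q,          # fp8 expert weights for the second grouped GEMM
    topk_weights,  # routing weights from fused_topk
    topk_ids,      # routing indices from fused_topk
    w1_scale,      # per-expert weight scales, now passed after topk_ids
    w2_scale,
    a1_scale=a_scale,
)
```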
9 changes: 8 additions & 1 deletion csrc/ops.h
@@ -231,7 +231,8 @@ void cutlass_moe_mm(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);

void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
@@ -245,6 +246,12 @@ void get_cutlass_moe_mm_data(
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k);

void get_cutlass_pplx_moe_mm_data(
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
const torch::Tensor& non_zero_expert_idxs, const int64_t num_local_experts,
const int64_t padded_m, const int64_t n, const int64_t k);

void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
29 changes: 19 additions & 10 deletions csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu
@@ -84,7 +84,8 @@ void run_cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
@@ -113,19 +114,23 @@ void run_cutlass_moe_mm_sm90(
if (n >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (k >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (m <= 16) {
cutlass_group_gemm_caller<Cutlass3xGemmM16>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else {
cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
}
}

@@ -134,15 +139,18 @@ void dispatch_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.dtype() == torch::kBFloat16) {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
}
}

@@ -153,8 +161,9 @@ void cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides);
c_strides, per_act_token, per_out_ch);
}
6 changes: 2 additions & 4 deletions csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
@@ -76,17 +76,15 @@ void cutlass_group_gemm_caller(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;

int num_experts = static_cast<int>(expert_offsets.size(0));
int k_size = a_tensors.size(1);
int n_size = out_tensors.size(1);

bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;

auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

auto options_int =
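Since per_act_token and per_out_ch are no longer inferred inside cutlass_group_gemm_caller, callers of cutlass_moe_mm must now derive and pass them explicitly. A hedged sketch of that derivation, mirroring the two checks removed in the hunk above (the exact vLLM call site is not shown in this diff; names are illustrative):

```python
import torch

def moe_quant_flags(a_scales: torch.Tensor, b_scales: torch.Tensor,
                    num_experts: int) -> tuple[bool, bool]:
    # Mirrors the logic removed from cutlass_group_gemm_caller: a single
    # activation scale means per-tensor activation quantization, otherwise
    # per-token; one weight scale per expert means per-tensor weights,
    # otherwise per-output-channel.
    per_act_token = a_scales.numel() != 1
    per_out_ch = b_scales.numel() != num_experts
    return per_act_token, per_out_ch

# Example (illustrative): compute the flags once on the host and append them
# as the new trailing arguments of the cutlass_moe_mm op.
# per_act_token, per_out_ch = moe_quant_flags(a_scale, w1_scale, num_experts)
```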
32 changes: 32 additions & 0 deletions csrc/quantization/cutlass_w8a8/moe/moe_data.cu
@@ -101,3 +101,35 @@ void get_cutlass_moe_mm_data_caller(
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
topk_ids.size(1));
}

__global__ void compute_pplx_data(
int32_t* expert_offsets, int32_t* problem_sizes1, int32_t* problem_sizes2,
const int32_t* __restrict__ expert_num_tokens,
const int64_t* __restrict__ non_zero_expert_idxs, const int padded_m,
const int n, const int k) {
int expert_idx_out = threadIdx.x;
int expert_idx_in = non_zero_expert_idxs[expert_idx_out];
expert_offsets[expert_idx_out] = expert_idx_in * padded_m;

Collaborator commented:

Suggested change
-    int expert_idx_in = non_zero_expert_idxs[expert_idx_out];
+    int expert_idx_in = static_cast<int32_t>(non_zero_expert_idxs[expert_idx_out]);
     expert_offsets[expert_idx_out] = expert_idx_in * padded_m;

Could you doublecheck and try to not add any warnings to the build? (The implicit down-conversion here looks safe enough to me, but best to avoid implicit conversions)

Contributor (Author) replied:

I don't see this producing any new warnings, I'll make the change

problem_sizes1[expert_idx_out * 3] = expert_num_tokens[expert_idx_in];
problem_sizes1[expert_idx_out * 3 + 1] = 2 * n;
problem_sizes1[expert_idx_out * 3 + 2] = k;
problem_sizes2[expert_idx_out * 3] = expert_num_tokens[expert_idx_in];
problem_sizes2[expert_idx_out * 3 + 1] = k;
problem_sizes2[expert_idx_out * 3 + 2] = n;
}

void get_cutlass_pplx_moe_mm_data_caller(
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
const torch::Tensor& non_zero_expert_idxs, const int64_t num_local_experts,
const int64_t padded_m, const int64_t n, const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());

compute_pplx_data<<<1, num_local_experts, 0, stream>>>(
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<const int32_t*>(expert_num_tokens.data_ptr()),
static_cast<const int64_t*>(non_zero_expert_idxs.data_ptr()), padded_m, n,
k);
}
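As a readability aid, here is a hedged pure-Python reference of what compute_pplx_data writes (the CUDA kernel uses one thread per local expert; 2-D views of the flat problem_sizes buffers are assumed here for clarity):

```python
import torch

def compute_pplx_data_ref(expert_num_tokens: torch.Tensor,     # int32, tokens routed to each expert
                          non_zero_expert_idxs: torch.Tensor,  # int64, output slot -> source expert
                          num_local_experts: int, padded_m: int, n: int, k: int):
    expert_offsets = torch.empty(num_local_experts, dtype=torch.int32)
    problem_sizes1 = torch.empty((num_local_experts, 3), dtype=torch.int32)
    problem_sizes2 = torch.empty((num_local_experts, 3), dtype=torch.int32)
    for out_idx in range(num_local_experts):
        in_idx = int(non_zero_expert_idxs[out_idx])
        # Each expert's rows start at a fixed, padded offset in the PPLX buffer.
        expert_offsets[out_idx] = in_idx * padded_m
        m_e = int(expert_num_tokens[in_idx])
        # Per-expert (M, N, K): first grouped GEMM is (m_e, 2n, k),
        # second grouped GEMM is (m_e, k, n), matching the kernel above.
        problem_sizes1[out_idx] = torch.tensor([m_e, 2 * n, k], dtype=torch.int32)
        problem_sizes2[out_idx] = torch.tensor([m_e, k, n], dtype=torch.int32)
    return expert_offsets, problem_sizes1, problem_sizes2
```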
36 changes: 33 additions & 3 deletions csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -36,7 +36,8 @@ void cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);

#endif

@@ -55,6 +56,12 @@ void get_cutlass_moe_mm_data_caller(
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k);

void get_cutlass_pplx_moe_mm_data_caller(
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
const torch::Tensor& non_zero_expert_idxs, const int64_t num_local_experts,
const int64_t padded_m, const int64_t n, const int64_t k);
#endif

void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
@@ -206,12 +213,13 @@ void cutlass_moe_mm(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides);
c_strides, per_act_token, per_out_ch);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
@@ -242,6 +250,28 @@ void get_cutlass_moe_mm_data(
version_num, ". Required capability: 90");
}

void get_cutlass_pplx_moe_mm_data(
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
const torch::Tensor& non_zero_expert_idxs, const int64_t num_local_experts,
const int64_t padded_m, const int64_t n, const int64_t k) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)

Collaborator commented:

Should this be the following?

Suggested change
-  #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
-      (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
+  #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90)

I'm not sure why we're looking at ENABLE_SCALED_MM_SM90, but the check for ENABLE_SCALED_MM_SM100 definitely looks wrong

Contributor (Author) replied:

Copy-paste issue, I didn't notice that the old MoE data kernel I copied it from has SM100 support for fp4 now

get_cutlass_pplx_moe_mm_data_caller(
expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens,
non_zero_expert_idxs, num_local_experts, padded_m, n, k);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
"for CUDA device capability: ",
version_num, ". Required capability: 90");
}

void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,