Skip to content

[Kernel] Integrate CUTLASS MoE kernel with PPLX #18762

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Jun 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
888177a
Cutlass MoE pplx - working unit tests
ElizaWszola May 19, 2025
6486345
Working e2e, but there are some hacks and it needs cleaning
ElizaWszola May 20, 2025
5d4751f
Set the correct workspace shapes, padded and unpadded c1,c2,c3
ElizaWszola May 26, 2025
9cb7802
format
ElizaWszola May 26, 2025
f34d6b1
Merge branch 'main' into cutlass-moe-pplx-integration
ElizaWszola May 26, 2025
9499f74
uncomment quant_method selection
ElizaWszola May 26, 2025
df1a014
Working e2e after merge
ElizaWszola May 27, 2025
503a9b3
Nuke output map codepath, clean up a bit
ElizaWszola May 27, 2025
8c1d57b
Fix the non-pplx codepath
ElizaWszola May 28, 2025
268bbea
CUDA kernel for pplx data computation, cleanups, fixing unit tests an…
ElizaWszola May 28, 2025
a6236bf
Various cleanups
ElizaWszola May 28, 2025
ab46919
Better types and attribute checks
ElizaWszola May 28, 2025
3e436fd
Missing return, type check ignore
ElizaWszola May 28, 2025
787a660
Various feedback
ElizaWszola May 30, 2025
34d4410
Fix workspace_shapes in deep gemm moe
ElizaWszola May 30, 2025
918add7
Clean up cutlass_moe_fp8 function
ElizaWszola Jun 2, 2025
ff4d923
Merge branch 'main' into cutlass-moe-pplx-integration
ElizaWszola Jun 3, 2025
79c6a94
Make compatible with cudagraphs
ElizaWszola Jun 3, 2025
1cb0fb6
Merge branch 'main' into cutlass-moe-pplx-integration
ElizaWszola Jun 4, 2025
ec55bee
Comment about different codepaths, is_pplz variable
ElizaWszola Jun 5, 2025
f5b4ff0
Merge branch 'main' into cutlass-moe-pplx-integration
ElizaWszola Jun 5, 2025
f6e968d
fix workspace_shapes args
ElizaWszola Jun 5, 2025
24fa6cb
Merge branch 'main' into cutlass-moe-pplx-integration
tlrmchlsmth Jun 5, 2025
ba6ce8f
Merge branch 'main' into cutlass-moe-pplx-integration
ElizaWszola Jun 6, 2025
db5aa2d
Import fixes, headers
ElizaWszola Jun 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -543,8 +543,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# CUTLASS MoE kernels

# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
# on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
# to compile MoE kernels that use its output.
# on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled
# if it's possible to compile MoE kernels that use its output.
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
Expand Down
61 changes: 11 additions & 50 deletions benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import (
cutlass_moe_fp8,
fused_experts,
fused_topk,
)
Expand Down Expand Up @@ -70,18 +70,9 @@ def bench_run(
w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)

ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)

for expert in range(num_experts):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
w1_q_notransp = w1_q.clone()
w2_q_notransp = w2_q.clone()
w1_q = w1_q.transpose(1, 2)
w2_q = w2_q.transpose(1, 2)

score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

Expand Down Expand Up @@ -122,25 +113,17 @@ def run_cutlass_moe(
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
num_repeats: int,
):
for _ in range(num_repeats):
cutlass_moe_fp8(
a,
w1,
w2,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
w1_scale,
w2_scale,
a1_scale=a_scale,
)

Expand All @@ -153,10 +136,6 @@ def run_cutlass_from_graph(
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
):
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
Expand All @@ -165,14 +144,10 @@ def run_cutlass_from_graph(
a,
w1_q,
w2_q,
w1_scale,
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
w1_scale,
w2_scale,
a1_scale=a_scale,
)

Expand Down Expand Up @@ -218,10 +193,6 @@ def replay_graph(graph, num_repeats):
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
)
torch.cuda.synchronize()

Expand All @@ -230,8 +201,8 @@ def replay_graph(graph, num_repeats):
with torch.cuda.graph(triton_graph, stream=triton_stream):
run_triton_from_graph(
a,
w1_q_notransp,
w2_q_notransp,
w1_q,
w2_q,
topk_weights,
topk_ids,
w1_scale,
Expand All @@ -250,18 +221,12 @@ def replay_graph(graph, num_repeats):
"w2": w2,
"score": score,
"topk": topk,
"w1_q_notransp": w1_q_notransp,
"w2_q_notransp": w2_q_notransp,
# Cutlass params
"a_scale": a_scale,
"w1_q": w1_q,
"w2_q": w2_q,
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"ab_strides1": ab_strides1,
"c_strides1": c_strides1,
"ab_strides2": ab_strides2,
"c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
Expand All @@ -279,8 +244,8 @@ def replay_graph(graph, num_repeats):
# Warmup
run_triton_moe(
a,
w1_q_notransp,
w2_q_notransp,
w1_q,
w2_q,
topk_weights,
topk_ids,
w1_scale,
Expand All @@ -291,7 +256,7 @@ def replay_graph(graph, num_repeats):

results.append(
benchmark.Timer(
stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
Expand Down Expand Up @@ -322,16 +287,12 @@ def replay_graph(graph, num_repeats):
w2_scale,
topk_weights,
topk_ids,
ab_strides1,
c_strides1,
ab_strides2,
c_strides2,
num_warmup,
)

results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
Expand Down
11 changes: 10 additions & 1 deletion csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ void cutlass_moe_mm(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);

void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
Expand All @@ -251,6 +252,14 @@ void get_cutlass_moe_mm_data(
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets);

void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k);

void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
Expand Down
29 changes: 19 additions & 10 deletions csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ void run_cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
Expand Down Expand Up @@ -113,19 +114,23 @@ void run_cutlass_moe_mm_sm90(
if (n >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (k >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (m <= 16) {
cutlass_group_gemm_caller<Cutlass3xGemmM16>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else {
cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
}
}

Expand All @@ -134,15 +139,18 @@ void dispatch_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.dtype() == torch::kBFloat16) {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
}
}

Expand All @@ -153,8 +161,9 @@ void cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides);
c_strides, per_act_token, per_out_ch);
}
6 changes: 2 additions & 4 deletions csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,15 @@ void cutlass_group_gemm_caller(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;

int num_experts = static_cast<int>(expert_offsets.size(0));
int k_size = a_tensors.size(1);
int n_size = out_tensors.size(1);

bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;

auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

auto options_int =
Expand Down
42 changes: 38 additions & 4 deletions csrc/quantization/cutlass_w8a8/moe/moe_data.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

constexpr uint64_t THREADS_PER_EXPERT = 512;

__global__ void compute_problem_sizes(const int* __restrict__ topk_ids,
__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
int32_t* atomic_buffer,
Expand Down Expand Up @@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
}
}

__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation,
int32_t* output_permutation,
Expand Down Expand Up @@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(

int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
Expand All @@ -120,10 +120,44 @@ void get_cutlass_moe_mm_data_caller(
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
}
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(input_permutation.data_ptr()),
static_cast<int32_t*>(output_permutation.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
topk_ids.size(1));
}

__global__ void compute_pplx_data(int32_t* expert_offsets,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
const int32_t* __restrict__ expert_num_tokens,
const int padded_m, const int n,
const int k) {
int expert_idx = threadIdx.x;

expert_offsets[expert_idx] = expert_idx * padded_m;
problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
problem_sizes1[expert_idx * 3 + 1] = 2 * n;
problem_sizes1[expert_idx * 3 + 2] = k;
problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
problem_sizes2[expert_idx * 3 + 1] = k;
problem_sizes2[expert_idx * 3 + 2] = n;
}

void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m,
const int64_t n, const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());

compute_pplx_data<<<1, num_local_experts, 0, stream>>>(
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k);
}
Loading