
[Kernel] Integrate CUTLASS MoE kernel with PPLX #18762


Open
wants to merge 15 commits into main

Conversation

@ElizaWszola (Contributor) commented May 27, 2025

Integrate CUTLASS MoE fp8 kernels with PPLX.

Unit tests:

tests/kernels/moe/test_pplx_cutlass_moe.py

E2E testing:

export MASTER_ADDR=127.0.0.1
export MASTER_PORT=29500
export VLLM_ALL2ALL_BACKEND=pplx
python3 examples/offline_inference/data_parallel.py \
        --model="nm-testing/DeepSeek-Coder-V2-Lite-Instruct-FP8" \
        --dp-size=2 \
        --tp-size=1 \
        --trust-remote-code


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@ElizaWszola ElizaWszola changed the title [Kernel] Integrate CUTLASS MoE kernel with PPLX [WIP][Kernel] Integrate CUTLASS MoE kernel with PPLX May 27, 2025
@mergify mergify bot added the ci/build label May 28, 2025
@ElizaWszola ElizaWszola changed the title [WIP][Kernel] Integrate CUTLASS MoE kernel with PPLX [Kernel] Integrate CUTLASS MoE kernel with PPLX May 28, 2025
@ElizaWszola ElizaWszola marked this pull request as ready for review May 28, 2025 16:10
NUM_EXPERTS = [40, 64]
TOP_KS = [6, 8]

P = ParamSpec("P")
Contributor

We should probably put all these multiprocess utilities in a separate file now since they are also used by test_pplx_moe.py

(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_moe_pptx(
Contributor

Typo: pptx -> pplx

@@ -812,6 +847,7 @@ def __init__(
         assert quant_method is not None
         assert isinstance(quant_method, FusedMoEMethodBase)
         self.quant_method = quant_method
+        self.quant_method.moe = moe
Contributor

This seems a bit sketchy to me. If the quant_method needs a MoEConfig it should be part of the constructor or passed around as an argument.
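For illustration, a minimal sketch of the constructor-injection alternative being suggested; the classes below are stand-ins, not the real MoEConfig or FusedMoEMethodBase:

from dataclasses import dataclass


@dataclass
class MoEConfigStub:
    """Stand-in for MoEConfig; only the field used here."""
    num_experts: int


class QuantMethodStub:
    """Stand-in for a quant method that takes the config up front."""

    def __init__(self, moe: MoEConfigStub):
        # The config is owned from construction instead of being patched
        # onto the instance afterwards via `self.quant_method.moe = moe`.
        self.moe = moe


quant_method = QuantMethodStub(moe=MoEConfigStub(num_experts=64))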

@@ -140,7 +153,7 @@ def finalize(
         topk_weights = torch.ones_like(topk_weights)

         self.a2a.combine(out_tokens=output,
-                         indices=topk_ids,
+                         indices=topk_ids.to(torch.uint32),
Contributor (@bnellnm, May 28, 2025)

Ditto: use indices_type here as well.

@@ -110,9 +121,11 @@ def prepare(
             out_expert_x_scale=expert_x_scale,
             dp_x=a1q,
             dp_x_scale=a1q_scale,
-            indices=rank_topk_ids,
+            indices=rank_topk_ids.to(torch.uint32),
Contributor

You shouldn't need this cast anymore. The type of the topk_ids can be controlled by passing torch.uint32 via indices_type to select_experts.
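For illustration, a toy routing helper showing the idea of controlling the index dtype at the source rather than casting afterwards (the real select_experts takes many more routing arguments; this is not its actual signature):

import torch


def select_experts_toy(router_logits: torch.Tensor,
                       top_k: int,
                       indices_type: torch.dtype = torch.int64):
    """Toy stand-in: emit topk ids directly in the dtype the dispatcher wants."""
    scores = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
    return topk_weights, topk_ids.to(indices_type)


logits = torch.randn(4, 8)
_, topk_ids = select_experts_toy(logits, top_k=2, indices_type=torch.uint32)
print(topk_ids.dtype)  # torch.uint32, so no extra .to(torch.uint32) at dispatch time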

Comment on lines +127 to +128
if expert_x_scale is not None:
    expert_x_scale = expert_x_scale[:, :, 0:1]
Contributor

Is this the same as expert_x_scale.view(-1, -1, 1)?

Contributor Author

No, this is taking only one slice from expert_x_scale's last dim. This is related to the scale format required by dispatch().
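To illustrate the distinction being discussed (toy shapes, not the real expert_x_scale):

import torch

x = torch.randn(8, 4, 2)    # stand-in for expert_x_scale
sliced = x[:, :, 0:1]       # keeps only the first entry of the last dim
print(sliced.shape)         # torch.Size([8, 4, 1])

# A view/reshape keeps every element (and view() allows at most one -1 dim),
# so it is not equivalent to selecting a single slice of the last dim.
reshaped = x.reshape(8, 4, 2)
print(reshaped.shape)       # torch.Size([8, 4, 2])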

Contributor

Can you elaborate on that? I copied the setup of expert_x_scales directly from pplx's test_all_to_all.py test so I assumed it would have the proper format already.

Comment on lines 64 to 76
if self.per_act_token:
    repeat_rows = 1
    a1q, a1q_scale = moe_kernel_quantize_input(a1, None,
                                               self.quant_dtype,
                                               self.per_act_token,
                                               self.block_shape)
else:
    repeat_rows = a1.shape[0]
    a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
                                               self.quant_dtype,
                                               self.per_act_token,
                                               self.block_shape)

Contributor

nit: can you collapse these branches together? e.g.

repeat_rows = 1 if self.per_act_token else a1.shape[0]
a1q, a1q_scale = moe_kernel_quantize_input(
                a1,
                a1_scale if not self.per_act_token else None,
                self.quant_dtype,
                self.per_act_token,
                self.block_shape)

Or even simpler: if a1_scale is None iff self.per_act_token, you can just pass a1_scale directly.

Collaborator (@tlrmchlsmth) left a comment

Left a few comments, but looks good overall -- let's try to get it landed once those and Bill's comments are addressed!

Comment on lines 111 to 112
int expert_idx_in = non_zero_expert_idxs[expert_idx_out];
expert_offsets[expert_idx_out] = expert_idx_in * padded_m;
Collaborator

Suggested change
- int expert_idx_in = non_zero_expert_idxs[expert_idx_out];
- expert_offsets[expert_idx_out] = expert_idx_in * padded_m;
+ int expert_idx_in = static_cast<int32_t>(non_zero_expert_idxs[expert_idx_out]);
+ expert_offsets[expert_idx_out] = expert_idx_in * padded_m;

Could you double-check and try not to add any warnings to the build? (The implicit down-conversion here looks safe enough to me, but it's best to avoid implicit conversions.)

Contributor Author

I don't see this producing any new warnings; I'll make the change.

Comment on lines 261 to 262
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
Collaborator

Should this be the following?

Suggested change
- #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
-     (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
+ #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90)

I'm not sure why we're looking at ENABLE_SCALED_MM_SM90, but the check for ENABLE_SCALED_MM_SM100 definitely looks wrong

Contributor Author

Copy-paste issue; I didn't notice that the old MoE data kernel I copied this from now has SM100 support for fp4.

from vllm.platforms import current_platform

try:
    from pplx_kernels import AllToAll  # or AllToAllInternode?
Collaborator

IIUC, AllToAll dispatches to AllToAllInternode under the hood, so we shouldn't need to interact with it directly

Comment on lines 554 to 555
((self.moe.num_experts + prepare_finalize.world_size - 1) //
prepare_finalize.world_size), self.moe.in_dtype,
Collaborator

I think it would be nice to factor out (self.moe.num_experts + prepare_finalize.world_size - 1) // prepare_finalize.world_size into a local variable for clarity, as it would help explain what it's doing.
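For reference, the expression in question is a ceil division of the global expert count over ranks; the variable name below is illustrative, not from the PR:

num_experts = 40   # self.moe.num_experts in the PR
world_size = 3     # prepare_finalize.world_size in the PR

# ceil(num_experts / world_size) without floating point
num_local_experts = (num_experts + world_size - 1) // world_size
print(num_local_experts)  # 14: each rank handles at most ceil(40 / 3) experts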

Comment on lines 329 to 330
padded_M = (self.prepare_finalize.max_num_tokens *
self.prepare_finalize.world_size)
Contributor

Should this be max_num_tokens * (world_size // dp_size)?

Contributor Author

It should, thanks! I'll be getting padded_M directly from a1q now though

Comment on lines 326 to 333
# compute padded_M for PplxPrepareAndFinalize
if (hasattr(self.prepare_finalize, 'max_num_tokens')
        and hasattr(self.prepare_finalize, 'world_size')):
    padded_M = (self.prepare_finalize.max_num_tokens *
                self.prepare_finalize.world_size)
else:
    padded_M = M

Contributor (@bnellnm, May 28, 2025)

I think the logic to get padded_M would be better as an abstract method on FusedMoePrepareAndFinalize, specialized for pplx and returning M otherwise.

Another alternative might be to move the workspace allocation to after the call to prepare; that way a1q would have the proper padded_M shape information, although some of the existing workspace_shape methods might need adjustment then.
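A rough sketch of the first alternative (the abstract method), with illustrative class, method, and attribute names rather than the actual vLLM API:

class FusedMoEPrepareAndFinalizeSketch:
    """Stand-in base class."""

    def padded_num_tokens(self, M: int) -> int:
        # Default: no padding, workspaces are sized for the local batch.
        return M


class PplxPrepareAndFinalizeSketch(FusedMoEPrepareAndFinalizeSketch):
    def __init__(self, max_num_tokens: int, world_size: int, dp_size: int):
        self.max_num_tokens = max_num_tokens
        self.world_size = world_size
        self.dp_size = dp_size

    def padded_num_tokens(self, M: int) -> int:
        # pplx pads to max_num_tokens per rank in the DP group, as discussed above.
        return self.max_num_tokens * (self.world_size // self.dp_size)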

Contributor Author

I'll go for the second suggestion

@@ -192,6 +195,7 @@ class MoEConfig:
     moe_parallel_config: FusedMoEParallelConfig

     in_dtype: torch.dtype  # The activation type.
+    quant_dtype: torch.dtype = None
Contributor

in_dtype was intended to be the post-quantization activation type (this wasn't quite clear, though). I'm fine with adding another field as long as we still need both types. Otherwise, we should just keep one.

Contributor

Disregard this comment. I think we'll need both of these types.
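For reference, one plausible reading of the two fields after this exchange, written against a stand-in dataclass (the comments are my interpretation of the thread, not the PR's final docstrings):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MoEConfigSketch:
    in_dtype: torch.dtype                      # activation dtype entering the layer, e.g. torch.bfloat16
    quant_dtype: Optional[torch.dtype] = None  # dtype after activation quantization,
                                               # e.g. torch.float8_e4m3fn; None = unquantized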


if expert_num_tokens is not None:
    non_zero_mask = expert_num_tokens[:] != 0
    masked_local_E = int(non_zero_mask.sum().item())
Contributor

Is this going to interfere with cudagraphs?

Contributor Author

Potentially... I think I can circumvent it with a custom CUDA kernel and an extra mapping for expert_offsets and problem_sizes1/2 if needed.

if expert_num_tokens is not None:
    non_zero_mask = expert_num_tokens[:] != 0
    masked_local_E = int(non_zero_mask.sum().item())
    non_zero_expert_idxs = torch.nonzero(non_zero_mask).flatten()
Contributor

I think nonzero might cause trouble also.
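For context on the concern (illustration only, not PR code): both .item() and nonzero() break the static-shape, no-host-sync assumptions that CUDA graph capture relies on.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
expert_num_tokens = torch.tensor([3, 0, 5, 0], device=device)
mask = expert_num_tokens != 0

# .item() forces a device->host sync and returns a Python int, which a
# captured graph cannot re-derive on replay.
masked_local_E = int(mask.sum().item())

# nonzero() produces a tensor whose shape depends on the data, so the shape
# recorded at capture time may not match later inputs.
non_zero_idxs = torch.nonzero(mask).flatten()
print(masked_local_E, non_zero_idxs)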

Comment on lines +297 to +308
ws1 = a.size(0) * topk_ids.size(1) * max(w1_q.size(1), w2_q.size(1))
ws2 = a.size(0) * topk_ids.size(1) * w2_q.size(2)
workspace13 = torch.zeros(ws1, device=a.device, dtype=out_dtype)
workspace2 = torch.zeros(ws2, device=a.device, dtype=out_dtype)

if apply_router_weight_on_input:
    assert topk_ids.shape[
        1] == 1, "topk_ids must be 1 for apply_router_weight_on_input"
    a = a * topk_weights.to(a.dtype)

from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
a1q, a1q_scale = _fp8_quantize(a, a1_scale, per_act_token)
Contributor

Why is this bit pulled out? It should be handled by whatever PrepareAndFinalize object is used with CutlassExpertsFp8.

Contributor Author

This should be the function that runs the old version of the CUTLASS MoE when no PrepareAndFinalize is being used. I changed the structure of the functions/classes in this file; is it less messy now?

Contributor

I was thinking more in terms of not duplicating code and using the new modular classes to serve as the implementation of cutlass_moe_fp8. MoEPrepareAndFinalizeNoEP should be able to do all the preparation/finalization and doesn't do any communication.
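A toy sketch of the composition being described: pair a no-communication prepare/finalize object with the CUTLASS experts object behind a single modular kernel. The classes below are stand-ins for MoEPrepareAndFinalizeNoEP, CutlassExpertsFp8, and the modular wrapper, not the real implementations:

import torch


class NoEPPrepareAndFinalizeStub:
    """Stand-in: quantize on prepare, weight-apply/reduce on finalize, no all-to-all."""

    def prepare(self, a: torch.Tensor) -> torch.Tensor:
        return a

    def finalize(self, out: torch.Tensor) -> torch.Tensor:
        return out


class CutlassFp8ExpertsStub:
    """Stand-in for the grouped-GEMM expert computation."""

    def apply(self, a_q: torch.Tensor) -> torch.Tensor:
        return a_q


class ModularMoEStub:
    """Stand-in for the modular kernel that composes the two pieces."""

    def __init__(self, prepare_finalize, experts):
        self.prepare_finalize = prepare_finalize
        self.experts = experts

    def __call__(self, a: torch.Tensor) -> torch.Tensor:
        a_q = self.prepare_finalize.prepare(a)
        out = self.experts.apply(a_q)
        return self.prepare_finalize.finalize(out)


# cutlass_moe_fp8 could then be a thin wrapper over this composition instead
# of duplicating the quantize/weight-apply logic.
cutlass_moe_fp8_stub = ModularMoEStub(NoEPPrepareAndFinalizeStub(),
                                      CutlassFp8ExpertsStub())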

Comment on lines 329 to 331
if not apply_router_weight_on_input:
    out = out * topk_weights.view(topk_weights.shape[0],
                                  topk_weights.shape[1], 1).to(out_dtype)
Contributor

ditto

@@ -92,7 +95,7 @@ def prepare(
else 1) * float32_size
expert_x_scale = torch.empty(
Contributor

I don't know if you've got correctness yet, but I had to use torch.zeros here to get some of my fp8 + pplx tests working, probably due to the alignment/padding requirements of the scale bytes.

Contributor Author

I didn't see any issues with keeping empty here, but I'll try testing with more param configurations and see if any errors show up.
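A small illustration of the difference under discussion: torch.empty leaves the allocation uninitialized, so bytes that are never written contain garbage, whereas torch.zeros guarantees they are zero.

import torch

scale_buf = torch.empty(4)  # uninitialized: unwritten entries hold arbitrary bytes
safe_buf = torch.zeros(4)   # fully initialized: unwritten entries read back as 0
print(safe_buf)             # tensor([0., 0., 0., 0.])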

Comment on lines +297 to +298
input_activations = get_quant_config_input_activations(
    quant_config)
Contributor

Is this guaranteed to always exist? I've tried running Qwen/Qwen3-30B-A3B-FP8 and it didn't appear to have an input_activations field in its quant_config.

I think this is probably fine for now, but maybe add a None check. There's a more general problem of finding the proper quantization info for each MoE layer that still needs to be solved.
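A minimal sketch of the None check being suggested; the helper and the fallback below are stand-ins for illustration, not the real vLLM objects:

def get_quant_config_input_activations_stub(quant_config: dict):
    # Some checkpoints (e.g. Qwen/Qwen3-30B-A3B-FP8) don't carry this field.
    return quant_config.get("input_activations")


quant_config = {"weights": {"type": "float"}}  # no input_activations entry
input_activations = get_quant_config_input_activations_stub(quant_config)
if input_activations is None:
    input_activations = {"strategy": "tensor"}  # illustrative default, not vLLM's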
