[Kernel] Integrate CUTLASS MoE kernel with PPLX #18762
Conversation
NUM_EXPERTS = [40, 64]
TOP_KS = [6, 8]

P = ParamSpec("P")
We should probably put all these multiprocess utilities in a separate file now since they are also used by test_pplx_moe.py
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_moe_pptx(
Typo: pptx -> pplx
@@ -812,6 +847,7 @@ def __init__(
        assert quant_method is not None
        assert isinstance(quant_method, FusedMoEMethodBase)
        self.quant_method = quant_method
+       self.quant_method.moe = moe
This seems a bit sketchy to me. If the quant_method needs a MoEConfig, it should be part of the constructor or passed around as an argument.
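A minimal sketch of the constructor-based alternative (only FusedMoEMethodBase and MoEConfig come from the diff; everything else is hypothetical):

# Hypothetical sketch: make the MoEConfig an explicit constructor dependency
# instead of patching it onto the quant method after construction.
class FusedMoEMethodBase:
    def __init__(self, moe: "MoEConfig"):
        self.moe = moe  # the config is now a visible, required dependency

# The layer would then build the quant method with the config it already has:
# quant_method = SomeFp8MoEMethod(moe)   # placeholder subclass name
# self.quant_method = quant_method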
@@ -140,7 +153,7 @@ def finalize(
        topk_weights = torch.ones_like(topk_weights)

        self.a2a.combine(out_tokens=output,
-                        indices=topk_ids,
+                        indices=topk_ids.to(torch.uint32),
Ditto using indices_type
@@ -110,9 +121,11 @@ def prepare(
            out_expert_x_scale=expert_x_scale,
            dp_x=a1q,
            dp_x_scale=a1q_scale,
-           indices=rank_topk_ids,
+           indices=rank_topk_ids.to(torch.uint32),
You shouldn't need this cast anymore. The type of the topk_ids can be controlled by passing torch.uint32 via indices_type to select_experts.
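A small sketch of what the call sites could look like, assuming the indices_type argument on select_experts described above (the other argument names are placeholders):

import torch

# Hypothetical: request uint32 expert indices from the router up front ...
topk_weights, topk_ids = select_experts(   # assumed to accept indices_type, per the comment above
    hidden_states, router_logits, top_k=top_k,
    indices_type=torch.uint32)

# ... so dispatch()/combine() can consume topk_ids directly, with no per-call
# .to(torch.uint32) cast.
assert topk_ids.dtype == torch.uint32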
        if expert_x_scale is not None:
            expert_x_scale = expert_x_scale[:, :, 0:1]
Is this the same as expert_x_scale.view(-1, -1, 1)?
No, this is taking only one slice from expert_x_scale's last dim. This is related to the scale format required by dispatch().
Can you elaborate on that? I copied the setup of expert_x_scales directly from pplx's test_all_to_all.py test, so I assumed it would have the proper format already.
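For reference, a tiny standalone example of why the slice is not a reshape:

import torch

expert_x_scale = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

sliced = expert_x_scale[:, :, 0:1]   # shape (2, 3, 1): keeps only the first scale along the last dim
print(sliced.shape)                  # torch.Size([2, 3, 1])

# A view/reshape cannot drop the other 18 elements; it only reinterprets the
# same storage (and torch.Tensor.view allows at most one -1 dimension anyway).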
        if self.per_act_token:
            repeat_rows = 1
            a1q, a1q_scale = moe_kernel_quantize_input(a1, None,
                                                       self.quant_dtype,
                                                       self.per_act_token,
                                                       self.block_shape)
        else:
            repeat_rows = a1.shape[0]
            a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
                                                       self.quant_dtype,
                                                       self.per_act_token,
                                                       self.block_shape)
nit: can you collapse these branches together? e.g.

repeat_rows = 1 if self.per_act_token else a1.shape[0]
a1q, a1q_scale = moe_kernel_quantize_input(
    a1,
    a1_scale if not self.per_act_token else None,
    self.quant_dtype,
    self.per_act_token,
    self.block_shape)

Or even simpler: if a1_scale is None iff self.per_act_token, you can just pass a1_scale directly.
Left a few comments, but this looks good overall -- let's try to get it landed once those and Bill's comments are addressed!
    int expert_idx_in = non_zero_expert_idxs[expert_idx_out];
    expert_offsets[expert_idx_out] = expert_idx_in * padded_m;
-    int expert_idx_in = non_zero_expert_idxs[expert_idx_out];
-    expert_offsets[expert_idx_out] = expert_idx_in * padded_m;
+    int expert_idx_in = static_cast<int32_t>(non_zero_expert_idxs[expert_idx_out]);
+    expert_offsets[expert_idx_out] = expert_idx_in * padded_m;
Could you double-check and try not to add any warnings to the build? (The implicit down-conversion here looks safe enough to me, but it's best to avoid implicit conversions.)
I don't see this producing any new warnings, I'll make the change
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
    (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
Should this be the following?
-#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
-    (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
+#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90)
I'm not sure why we're looking at ENABLE_SCALED_MM_SM90, but the check for ENABLE_SCALED_MM_SM100 definitely looks wrong.
Copy-paste issue; I didn't notice that the old MoE data kernel I copied this from now has SM100 support for fp4.
from vllm.platforms import current_platform

try:
    from pplx_kernels import AllToAll  # or AllToAllInternode?
IIUC, AllToAll dispatches to AllToAllInternode under the hood, so we shouldn't need to interact with it directly.
    ((self.moe.num_experts + prepare_finalize.world_size - 1) //
     prepare_finalize.world_size), self.moe.in_dtype,
I think it would be nice to factor out (self.moe.num_experts + prepare_finalize.world_size - 1) // prepare_finalize.world_size into a local variable for clarity, as it would help explain what it's doing.
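Something like the following, for instance (the variable names are just suggestions):

# Ceiling division: number of experts assigned to each rank.
world_size = prepare_finalize.world_size
num_local_experts = (self.moe.num_experts + world_size - 1) // world_size
# ... then pass num_local_experts and self.moe.in_dtype as before.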
            padded_M = (self.prepare_finalize.max_num_tokens *
                        self.prepare_finalize.world_size)
Should this be max_num_tokens * (world_size // dp_size)?
It should, thanks! I'll be getting padded_M directly from a1q now though.
        # compute padded_M for PplxPrepareAndFinalize
        if (hasattr(self.prepare_finalize, 'max_num_tokens')
                and hasattr(self.prepare_finalize, 'world_size')):
            padded_M = (self.prepare_finalize.max_num_tokens *
                        self.prepare_finalize.world_size)
        else:
            padded_M = M
I think the logic to get padded_M would be better as an abstract method on FusedMoePrepareAndFinalize, specialized for pplx and returning M otherwise.

Another alternative might be to move the workspace allocation after the call to prepare. That way a1q would have the proper padded_M shape information, although some of the existing workspace_shape methods might need some adjustment then.
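A rough sketch of the first alternative, assuming the base class is FusedMoePrepareAndFinalize as named above (the method name, the pplx subclass name, and its attributes are hypothetical):

from abc import ABC

class FusedMoePrepareAndFinalize(ABC):
    def padded_num_tokens(self, M: int) -> int:
        # Default: no padding, workspaces are sized for M tokens.
        return M

class PplxPrepareAndFinalize(FusedMoePrepareAndFinalize):
    def padded_num_tokens(self, M: int) -> int:
        # pplx pads per-rank buffers, using the sizes discussed above.
        return self.max_num_tokens * (self.world_size // self.dp_size)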
I'll go for the second suggestion (moving the workspace allocation after the call to prepare).
@@ -192,6 +195,7 @@ class MoEConfig:
    moe_parallel_config: FusedMoEParallelConfig

    in_dtype: torch.dtype  # The activation type.
+   quant_dtype: torch.dtype = None
in_dtype was intended to be the post-quantization activation type (this wasn't quite clear though). I'm fine with adding another field as long as we still need both types. Otherwise, we should just keep one.
Disregard this comment. I think we'll need both of these types.
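To summarize the discussion, a partial sketch of how the two fields could be documented (the field comments are an interpretation, not the final wording):

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class MoEConfig:  # partial sketch; other fields omitted
    in_dtype: torch.dtype                      # activation dtype entering the MoE layer
    quant_dtype: Optional[torch.dtype] = None  # dtype activations are quantized to (None = unquantized)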
        if expert_num_tokens is not None:
            non_zero_mask = expert_num_tokens[:] != 0
            masked_local_E = int(non_zero_mask.sum().item())
Is this going to interfere with cudagraphs?
Potentially... I think I can circumvent it with a custom CUDA kernel and an extra mapping for expert_offsets and problem_sizes1/2 if needed.
        if expert_num_tokens is not None:
            non_zero_mask = expert_num_tokens[:] != 0
            masked_local_E = int(non_zero_mask.sum().item())
            non_zero_expert_idxs = torch.nonzero(non_zero_mask).flatten()
I think nonzero might cause trouble also.
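For context, a small illustration of why these two calls are problematic under CUDA graph capture (illustrative only; the actual fix discussed above would be a custom kernel plus extra index mappings):

import torch

expert_num_tokens = torch.tensor([3, 0, 5, 0], device="cuda")

non_zero_mask = expert_num_tokens != 0
masked_local_E = int(non_zero_mask.sum().item())               # .item() forces a device-to-host sync
non_zero_expert_idxs = torch.nonzero(non_zero_mask).flatten()  # output shape depends on the data

# Graph capture requires fixed shapes and no host synchronization, so both
# lines above would have to be replaced with device-side, fixed-shape logic.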
    ws1 = a.size(0) * topk_ids.size(1) * max(w1_q.size(1), w2_q.size(1))
    ws2 = a.size(0) * topk_ids.size(1) * w2_q.size(2)
    workspace13 = torch.zeros(ws1, device=a.device, dtype=out_dtype)
    workspace2 = torch.zeros(ws2, device=a.device, dtype=out_dtype)

    if apply_router_weight_on_input:
        assert topk_ids.shape[
            1] == 1, "topk_ids must be 1 for apply_router_weight_on_input"
        a = a * topk_weights.to(a.dtype)

    from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
    a1q, a1q_scale = _fp8_quantize(a, a1_scale, per_act_token)
Why is this bit pulled out? It should be handled by whatever PrepareAndFinalize object is used with CutlassExpertsFp8.
This should be the function that runs the old version of the cutlass MoE when no PrepareAndFinalize is being run. I changed the structure of the functions/classes in this file; is it less messy now?
I was thinking more in terms of not duplicating code and using the new modular classes to serve as the implementation of cutlass_moe_fp8. MoEPrepareAndFinalizeNoEP should be able to do all the preparation/finalization and doesn't do any communication.
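A rough sketch of the composition being suggested; MoEPrepareAndFinalizeNoEP and CutlassExpertsFp8 come from the discussion above, while the wrapper class name and constructor arguments are assumptions:

# Hypothetical: build cutlass_moe_fp8 out of the modular pieces instead of a
# hand-rolled quantize/compute/finalize path.
def cutlass_moe_fp8(a, w1_q, w2_q, topk_weights, topk_ids, **kwargs):
    prepare_finalize = MoEPrepareAndFinalizeNoEP()   # quantization + finalization, no communication
    experts = CutlassExpertsFp8()                    # grouped-gemm expert computation
    fused = FusedMoEModularKernel(prepare_finalize, experts)  # wrapper name is an assumption
    return fused(a, w1_q, w2_q, topk_weights, topk_ids, **kwargs)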
    if not apply_router_weight_on_input:
        out = out * topk_weights.view(topk_weights.shape[0],
                                      topk_weights.shape[1], 1).to(out_dtype)
ditto
@@ -92,7 +95,7 @@ def prepare(
                     else 1) * float32_size
        expert_x_scale = torch.empty(
I don't know if you've got correctness yet but I had to use torch.zeros here to get some of my fp8 + pplx tests working. Probably due to the alignment/padding requirements of the scale bytes.
I didn't see any issues with keeping empty here, but I'll try testing with more param configurations and see if any errors show up.
        input_activations = get_quant_config_input_activations(
            quant_config)
Is this guaranteed to always exist? I've tried running Qwen/Qwen3-30B-A3B-FP8 and it didn't appear to have an input_activations field in its quant_config.
I think this is probably fine for now, but maybe with a None check. There's a more general problem of finding the proper quantization info for each MoE layer that needs to be solved.
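Something along these lines, perhaps (get_quant_config_input_activations comes from the diff above; the fallback behaviour is left open):

input_activations = get_quant_config_input_activations(quant_config)
if input_activations is None:
    # Some checkpoints (e.g. Qwen/Qwen3-30B-A3B-FP8) have no input_activations
    # entry in their quant_config, so don't assume the field exists.
    ...  # hypothetical fallback, e.g. default to per-tensor activation scales
else:
    ...  # existing path that reads the scheme from input_activations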
Integrate CUTLASS MoE fp8 kernels with PPLX.
Unit tests:
E2E testing: