Commit 309e0b8

rootwenscarl authored and committed

wip
able to run ok

1 parent ec6acfd commit 309e0b8

7 files changed: +139 −65 lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -1007,7 +1007,7 @@ endif()
 # For CUDA we also build and ship some external projects.
 if (VLLM_GPU_LANG STREQUAL "CUDA")
   include(cmake/external_projects/flashmla.cmake)
-  include(cmake/external_projects/qutlass.cmake)
+  # include(cmake/external_projects/qutlass.cmake)

   # vllm-flash-attn should be last as it overwrites some CMake functions
   include(cmake/external_projects/vllm_flash_attn.cmake)

setup.py

Lines changed: 1 addition & 1 deletion

@@ -636,7 +636,7 @@ def _read_requirements(filename: str) -> list[str]:
         ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
         if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
             # FA3 requires CUDA 12.3 or later
-            ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
+            # ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
         # Optional since this doesn't get built (produce an .so file) when
         # not targeting a hopper system
         ext_modules.append(CMakeExtension(name="vllm._flashmla_C", optional=True))

tests/kernels/moe/test_cutedsl_moe.py

Lines changed: 10 additions & 9 deletions

@@ -9,10 +9,11 @@
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
     flashinfer_cutedsl_moe_masked,
-    scaled_fp4_grouped_quant,
 )
 from vllm.utils.flashinfer import (
     flashinfer_cutedsl_grouped_gemm_nt_masked as cutedsl_gmm_masked,
+    scaled_fp4_grouped_quantize,
+    silu_and_mul_scaled_nvfp4_experts_quantize,
 )

 if torch.cuda.get_device_capability() < (10, 0):

@@ -219,16 +220,16 @@ def flashinfer_cutedsl_grouped_gemm_nt_masked(
 ):
     # hidden_states: [l, m, k]
     # weights: [l, n, k]
-    aq, aq_sf = scaled_fp4_grouped_quant(
+    aq, aq_sf = scaled_fp4_grouped_quantize(
         hidden_states,
-        input_global_scale,
         masked_m.to(hidden_states.device),
+        input_global_scale,
     )
     num_experts, n, k = weights.shape
-    bq, bq_sf = scaled_fp4_grouped_quant(
+    bq, bq_sf = scaled_fp4_grouped_quantize(
         weights,
-        w_global_scale,
         torch.full((num_experts,), n, device=weights.device, dtype=torch.int32),
+        w_global_scale,
     )

     out = torch.zeros(

@@ -316,15 +317,15 @@ def test_flashinfer_cutedsl_moe_masked(
         (num_experts,), dtype=torch.float32, device=hidden_states.device
     ) # assume intermediate scale is 1.0

-    w1_fp4, w1_blockscale = scaled_fp4_grouped_quant(
+    w1_fp4, w1_blockscale = scaled_fp4_grouped_quantize(
         w1,
+        torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim,
         w1_global_scale,
-        torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim,
     )
-    w2_fp4, w2_blockscale = scaled_fp4_grouped_quant(
+    w2_fp4, w2_blockscale = scaled_fp4_grouped_quantize(
         w2,
+        torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim,
         w2_global_scale,
-        torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim,
     )

     w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
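
Note: the test now imports the quantize helper from vllm.utils.flashinfer rather than from flashinfer_cutedsl_moe, and masked_m moves ahead of the global scale in the call. Below is a minimal compatibility shim (hypothetical, not part of this commit) that maps the old call order onto the new helper, assuming the (tensor, masked_m, global_scale) order shown in the hunks above.

import torch

from vllm.utils.flashinfer import scaled_fp4_grouped_quantize


def scaled_fp4_grouped_quant_compat(
    tensor: torch.Tensor,
    global_scale: torch.Tensor,
    masked_m: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Old-style (tensor, global_scale, masked_m) order, forwarded to the
    # renamed helper, which now expects masked_m before the global scale.
    return scaled_fp4_grouped_quantize(tensor, masked_m, global_scale)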

vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py

Lines changed: 57 additions & 27 deletions

@@ -111,33 +111,46 @@ def _do_quant(
             x_fp8, x_scales = x
             x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)

-        assert isinstance(x, torch.Tensor)
-
-        num_experts, max_tokens, hidden_dim = x.size()
-
-        # TODO (varun): Optimization - Use a batched version of quant
-        x = x.view((-1, hidden_dim))
+        assert isinstance(x, (torch.Tensor, tuple))
         q_dtype = quant_config.quant_dtype

-        if envs.VLLM_FLASHINFER_MOE_BACKEND == "cutedsl":
-            logger.info_once(
-                "Skip quantization when using FlashInfer CUTEDSL for "
-                "ModelOptNvFp4FusedMoE."
+        if q_dtype == "nvfp4":
+            assert isinstance(x, tuple)
+            # num_experts, max_tokens, hidden_dim_by_4 = x[0].size()
+            # print(f"nvfp4 quantization input shape: {x[0].size()}, {x[0].dtype}, {x[0].is_contiguous()}")
+            # print(f"nvfp4 quantization input shape: {x[1].size()}, {x[1].dtype}, {x[1].is_contiguous()}")
+            # print("after permute")
+            x_scales = x[1]
+            x = x[0].permute(2, 0, 1)
+            # print(f"nvfp4 quantization input shape: {x.size()}, {x.dtype}, {x.is_contiguous()}")
+            num_experts, max_tokens, hidden_dim_by_2 = x.shape
+            hidden_dim = hidden_dim_by_2 * 2
+            assert(envs.VLLM_FLASHINFER_MOE_BACKEND == "cutedsl")
+            logger.info_once("skip nvfp4 quant since done by deepep!!")
+            # logger.info_once(
+            #     "Skip quantization when using FlashInfer CUTEDSL for "
+            #     "ModelOptNvFp4FusedMoE."
+            # )
+        else:
+            assert isinstance(x, torch.Tensor)
+            num_experts, max_tokens, hidden_dim = x.size()
+
+            # TODO (varun): Optimization - Use a batched version of quant
+            x = x.view((-1, hidden_dim))
+            x, x_scales = moe_kernel_quantize_input(
+                x,
+                quant_config.a1_scale,
+                q_dtype,
+                quant_config.per_act_token_quant,
+                quant_config.block_shape,
             )
-            q_dtype = None
-
-        x, x_scales = moe_kernel_quantize_input(
-            x,
-            quant_config.a1_scale,
-            q_dtype,
-            quant_config.per_act_token_quant,
-            quant_config.block_shape,
-        )
-        x = x.view((num_experts, -1, hidden_dim))
+            x = x.view((num_experts, -1, hidden_dim))
+

         if q_dtype is not None:
             assert x_scales is not None
-            x_scales = normalize_batched_scales_shape(x_scales, num_experts)
+            if q_dtype != "nvfp4":
+                x_scales = normalize_batched_scales_shape(x_scales, num_experts)

         return x, x_scales
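
Note: the nvfp4 branch of _do_quant above assumes the dispatched activations arrive already packed as FP4, two values per uint8 byte, with the expert dimension last until the permute. A standalone sketch of that shape bookkeeping with dummy tensors follows; the shapes and the blockscale layout are assumptions for illustration, not taken from DeepEP.

import torch

num_experts, max_tokens, hidden_dim = 4, 8, 256

# Dispatched payload as the branch assumes it: packed fp4 bytes plus
# per-16-element e4m3 blockscales, expert dimension last before the permute.
packed = torch.zeros(max_tokens, hidden_dim // 2, num_experts, dtype=torch.uint8)
blockscales = torch.zeros(max_tokens, hidden_dim // 16, num_experts).to(torch.float8_e4m3fn)

x = packed.permute(2, 0, 1)  # -> [num_experts, max_tokens, hidden_dim // 2]
_, _, hidden_dim_by_2 = x.shape
assert hidden_dim_by_2 * 2 == hidden_dim  # logical width recovered from packed bytes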

@@ -167,18 +180,26 @@ def prepare_async(
             "DeepEP kernels quantize the inputs in blocks of shape 128"
         )

+        use_nvfp4 = False
+        # print("mm"*100, quant_config.quant_dtype)
+        if quant_config.quant_dtype == "nvfp4":
+            # print("gg"*100)
+            # print(quant_config.a1_gscale)
+            use_nvfp4 = True
+        qc_a1_gscale_or_scale = quant_config.a1_gscale if quant_config.quant_dtype == "nvfp4" else quant_config.a1_scale
         has_per_token_scales = (
-            quant_config.a1_scale.numel() != 1
-            if quant_config.a1_scale is not None
+            qc_a1_gscale_or_scale.numel() != 1
+            if qc_a1_gscale_or_scale is not None
             else (
                 quant_config.a2_scale.numel() != 1
                 if quant_config.a2_scale is not None
                 else False
             )
         )
-        assert not has_per_token_scales, (
-            "low_latency kernels doesn't support dispatching per-token scales"
-        )
+        if not use_nvfp4:
+            assert not has_per_token_scales, (
+                "low_latency kernels doesn't support dispatching per-token scales"
+            )

         if apply_router_weight_on_input:
             topk = topk_ids.size(1)

@@ -189,12 +210,19 @@ def prepare_async(
             a1 = a1 * topk_weights.to(a1.dtype)

         # Dispatch
+        # print("qwerwqrq"*100, use_nvfp4, qc_a1_gscale_or_scale.shape, a1.shape, a1.dtype)
         expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
             a1,
             topk_ids,
             self.max_tokens_per_rank,
             num_experts,
             use_fp8=self.use_fp8_dispatch,
+            **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
+            **(
+                dict(x_global_scale=qc_a1_gscale_or_scale)
+                if qc_a1_gscale_or_scale is not None
+                else dict()
+            ),
             async_finish=False,
             return_recv_hook=True,
         )

@@ -220,7 +248,7 @@ def _receiver(
         quant_config: FusedMoEQuantConfig,
     ) -> mk.PrepareResultType:
         expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, quant_config)
-
+
         expert_tokens_meta = mk.ExpertTokensMetadata(
             expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
         )

@@ -275,6 +303,8 @@ def _finalize(

         # TODO (varun) : Enable zero copy mode
         dbo_maybe_run_recv_hook()
+        # print("combine"*100)
+        # print(fused_expert_output.shape, fused_expert_output.dtype)
         _, _, recv_hook = self.buffer.low_latency_combine(
             fused_expert_output,
             topk_ids,
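
Note: prepare_async forwards use_nvfp4 and x_global_scale through conditional dict unpacking, so the keywords are only passed when the nvfp4 path is active. An isolated sketch of that pattern against a stub callee follows; the stub is hypothetical, the real call targets DeepEP's Buffer.low_latency_dispatch.

from typing import Any


def low_latency_dispatch_stub(x: str, *, use_fp8: bool = False, **extra: Any) -> None:
    # Stand-in for the real dispatch entry point; just shows which kwargs arrive.
    print(f"use_fp8={use_fp8}, extra={extra}")


use_nvfp4 = True
x_global_scale = 1.0  # stand-in for quant_config.a1_gscale

low_latency_dispatch_stub(
    "tokens",
    use_fp8=False,
    **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
    **(dict(x_global_scale=x_global_scale) if x_global_scale is not None else dict()),
)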

vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py

Lines changed: 62 additions & 20 deletions

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
+from typing import Union
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk

@@ -18,6 +18,7 @@

 logger = init_logger(__name__)

+CUTEDSL_MOE_NVFP4_DISPATCH = True

 def is_valid_flashinfer_cutedsl_fused_moe(
     hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor

@@ -109,7 +110,12 @@ def workspace_shapes(
         - Note: in order for activation chunking to work, the first dimension
           of each tuple must be the number of tokens.
         """
-        output_shape = (local_num_experts, M, K)
+        if CUTEDSL_MOE_NVFP4_DISPATCH:
+            # since it sees quantized a1q
+            K_dim = K * 2
+        else:
+            K_dim = K
+        output_shape = (local_num_experts, M, K_dim)
         workspace2 = (local_num_experts, M, N)
         workspace1 = output_shape
         return (workspace1, workspace2, output_shape)

@@ -145,9 +151,10 @@ def apply(
         assert self.w1_scale.ndim == 3
         assert self.w2_scale.ndim == 3

+        # TODO(shuw): replace True by CUTEDSL_MOE_NVFP4_DISPATCH
         flashinfer_cutedsl_moe_masked(
-            hidden_states=hidden_states,
-            input_global_scale=self.a1_gscale,
+            hidden_states=(hidden_states, a1q_scale) if CUTEDSL_MOE_NVFP4_DISPATCH else hidden_states,
+            input_global_scale=(None if CUTEDSL_MOE_NVFP4_DISPATCH else self.a1_gscale),
             w1=w1,
             w1_blockscale=self.w1_scale,
             w1_alpha=self.g1_alphas,

@@ -173,7 +180,7 @@ def get_cute_dtype(input: torch.Tensor) -> str:


 def flashinfer_cutedsl_moe_masked(
-    hidden_states: torch.Tensor,
+    hidden_states: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
     input_global_scale: torch.Tensor,
     w1: torch.Tensor,
     w1_blockscale: torch.Tensor,

@@ -191,7 +198,9 @@ def flashinfer_cutedsl_moe_masked(
     kernels.

     Args:
-        hidden_states (torch.Tensor): [num_experts, m, k], bf16
+        hidden_states: Either of the following case
+            * torch.Tensor: [num_experts, m, k], bf16
+            * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2], uint8, [num_experts, m, k // 16], float8_e4m3fn
         input_global_scale (torch.Tensor): (l,)
         w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8
         w1_blockscale (torch.Tensor): blockscale factors, e4m3,

@@ -208,9 +217,9 @@ def flashinfer_cutedsl_moe_masked(
     """

     # === Assertions on dtypes ===
-    assert input_global_scale.dtype == torch.float32, (
-        f"input_global_scale must be float32, got {input_global_scale.dtype}"
-    )
+    # assert input_global_scale.dtype == torch.float32, (
+    #     f"input_global_scale must be float32, got {input_global_scale.dtype}"
+    # )
     assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}"
     assert w1_blockscale.dtype == torch.float8_e4m3fn, (
         f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}"

@@ -231,7 +240,32 @@ def flashinfer_cutedsl_moe_masked(

     # === Assertions on shapes ===
     n = w2.shape[-1] * 2 # intermediate dimension
-    num_experts, m, k = hidden_states.shape
+    if isinstance(hidden_states, tuple):
+        assert (
+            input_global_scale is None
+        ), "input_global_scale is needed when input needs quant"
+
+        aq = hidden_states[0].view(torch.uint8)
+        aq_sf = hidden_states[1].view(torch.float8_e4m3fn)
+        # m, k_by_2, num_experts = aq.shape
+        num_experts, m, k_by_2 = aq.shape
+        k = k_by_2 * 2
+        aq = aq.permute(1,2,0)
+    else:
+        num_experts, m, k = hidden_states.shape
+
+        assert (
+            input_global_scale.dtype == torch.float32
+        ), f"input_global_scale must be float32, got {input_global_scale.dtype}"
+        assert input_global_scale.shape == (
+            num_experts,
+        ), f"input_global_scale must be (l,), got {input_global_scale.shape}"
+
+        aq, aq_sf = scaled_fp4_grouped_quantize(
+            hidden_states,
+            masked_m,
+            input_global_scale,
+        )

     assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}"
     assert w1.shape[-1] * 2 == k, (

@@ -242,9 +276,9 @@ def flashinfer_cutedsl_moe_masked(
         n // 2,
     ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}"

-    assert input_global_scale.shape == (num_experts,), (
-        f"input_global_scale must be (l,), got {input_global_scale.shape}"
-    )
+    # assert input_global_scale.shape == (num_experts,), (
+    #     f"input_global_scale must be (l,), got {input_global_scale.shape}"
+    # )
     assert w1_alpha.shape == (num_experts,), (
         f"w1_alpha must be (l,), got {w1_alpha.shape}"
     )

@@ -254,23 +288,31 @@ def flashinfer_cutedsl_moe_masked(
     assert w2_alpha.shape == (num_experts,), (
         f"w2_alpha must be (l,), got {w2_alpha.shape}"
     )
+    # return

-    aq, aq_sf = scaled_fp4_grouped_quantize(
-        hidden_states,
-        masked_m,
-        input_global_scale,
-    )
+    # aq, aq_sf = scaled_fp4_grouped_quantize(
+    #     hidden_states,
+    #     masked_m,
+    #     input_global_scale,
+    # )

+    # workspace = workspace.permute(1, 2, 0) # requirement of kernel
+    # workspace = torch.empty(
+    #     (num_experts, m, n * 2), dtype=torch.bfloat16, device=aq.device
+    # )
     workspace = workspace.permute(1, 2, 0) # requirement of kernel
     sf_vec_size = 16
     assert aq_sf.dtype == torch.float8_e4m3fn
     assert aq.dtype == torch.uint8
     ab_dtype = "float4_e2m1fn"
     sf_dtype = "float8_e4m3fn"

-    c_dtype = get_cute_dtype(hidden_states)
+    # c_dtype = get_cute_dtype(hidden_states)
+    c_dtype = "bfloat16"

     # Gemm1
+    # print(aq.shape, aq.dtype)
+    # print(aq_sf.shape, aq_sf.dtype)
     flashinfer_cutedsl_grouped_gemm_nt_masked(
         (aq, aq_sf),
         (w1.permute(1, 2, 0), w1_blockscale),

@@ -290,7 +332,7 @@ def flashinfer_cutedsl_moe_masked(
         masked_m,
         a2_global_scale,
     )
-
+    # return
     # Gemm2
     out = out.permute(1, 2, 0) # requirement of kernel
     flashinfer_cutedsl_grouped_gemm_nt_masked(
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py

Lines changed: 1 addition & 0 deletions

@@ -179,6 +179,7 @@ def prepare(
             quant_config.block_shape,
             is_fp4_scale_swizzled=not self.use_dp,
         )
+
         if self.use_dp:
             topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
                 [topk_weights, topk_ids, a1q, a1q_scale],
