@@ -114,7 +114,7 @@ def _do_quant(
114114 assert isinstance (x , (torch .Tensor , tuple ))
115115 q_dtype = quant_config .quant_dtype
116116
117- if q_dtype == "nvfp4" :
117+ if q_dtype == "nvfp4" and envs . VLLM_DEEPEPLL_NVFP4_DISPATCH :
118118 assert isinstance (x , tuple )
119119 # num_experts, max_tokens, hidden_dim_by_4 = x[0].size()
120120 # print(f"nvfp4 quantization input shape: {x[0].size()}, {x[0].dtype}, {x[0].is_contiguous()}")
@@ -125,13 +125,18 @@ def _do_quant(
125125 # print(f"nvfp4 quantization input shape: {x.size()}, {x.dtype}, {x.is_contiguous()}")
126126 num_experts , max_tokens , hidden_dim_by_2 = x .shape
127127 hidden_dim = hidden_dim_by_2 * 2
128- assert (envs .VLLM_FLASHINFER_MOE_BACKEND == "cutedsl" )
129- logger .info_once ("skip nvfp4 quant since done by deepep!!" )
130- # logger.info_once(
131- # "Skip quantization when using FlashInfer CUTEDSL for "
132- # "ModelOptNvFp4FusedMoE."
133- # )
128+ assert (envs .VLLM_FLASHINFER_MOE_BACKEND == "cutedsl" )
129+ logger .info_once (
130+ "Quantization is fused with DeepEP nvfp4 dispatch for " \
131+ "FlashInfer CUTEDSL as VLLM_DEEPEPLL_NVFP4_DISPATCH==1"
132+ )
134133 else :
134+ if q_dtype == "nvfp4" :
135+ q_dtype = None
136+ logger .info_once (
137+ "Using DeepEP bfloat16 dispatch for FlashInfer CUTEDSL as " \
138+ "VLLM_DEEPEPLL_NVFP4_DISPATCH==0"
139+ )
135140 assert isinstance (x , torch .Tensor )
136141 num_experts , max_tokens , hidden_dim = x .size ()
137142
@@ -146,7 +151,6 @@ def _do_quant(
146151 )
147152 x = x .view ((num_experts , - 1 , hidden_dim ))
148153
149-
150154 if q_dtype is not None :
151155 assert x_scales is not None
152156 if q_dtype != "nvfp4" :
@@ -182,11 +186,12 @@ def prepare_async(
182186
183187 use_nvfp4 = False
184188 # print("mm"*100, quant_config.quant_dtype)
185- if quant_config .quant_dtype == "nvfp4" :
189+ nvfp4_dispatch = quant_config .quant_dtype == "nvfp4" and envs .VLLM_DEEPEPLL_NVFP4_DISPATCH
190+ if nvfp4_dispatch :
186191 # print("gg"*100)
187192 # print(quant_config.a1_gscale)
188193 use_nvfp4 = True
189- qc_a1_gscale_or_scale = quant_config .a1_gscale if quant_config . quant_dtype == "nvfp4" else quant_config .a1_scale
194+ qc_a1_gscale_or_scale = quant_config .a1_gscale if nvfp4_dispatch else quant_config .a1_scale
190195 has_per_token_scales = (
191196 qc_a1_gscale_or_scale .numel () != 1
192197 if qc_a1_gscale_or_scale is not None
0 commit comments