
Commit 925c15a: env var works
1 parent 309e0b8

3 files changed: +27 -16 lines

vllm/envs.py
Lines changed: 6 additions & 0 deletions

@@ -1057,6 +1057,12 @@ def get_vllm_port() -> int | None:
     "VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool(
         os.environ.get("VLLM_MXFP4_USE_MARLIN", None)
     ),
+    # Whether to use the DeepEP low-latency (DeepEPLL) kernels for the
+    # NVFP4 quantization-and-dispatch method. Only supported on Blackwell
+    # GPUs and requires https://github.com/deepseek-ai/DeepEP/pull/341
+    "VLLM_DEEPEPLL_NVFP4_DISPATCH": lambda: bool(
+        int(os.getenv("VLLM_DEEPEPLL_NVFP4_DISPATCH", "0"))
+    ),
     # Whether to turn on the outlines cache for V0
     # This cache is unbounded and on disk, so it's not safe to use in
     # an environment with potentially malicious users.
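For reference, the new lambda treats the variable as a strict integer flag: unset or "0" disables the NVFP4 dispatch path, "1" enables it, and a non-integer value such as "true" raises ValueError rather than silently enabling it. A minimal standalone sketch of the same parsing (the helper name _env_flag is illustrative, not vLLM API):

import os

def _env_flag(name: str, default: str = "0") -> bool:
    # Same shape as the envs.py entry: bool(int(os.getenv(name, "0"))).
    # "0"/unset -> False, any nonzero integer string -> True,
    # non-integer strings raise ValueError instead of enabling the path.
    return bool(int(os.getenv(name, default)))

os.environ["VLLM_DEEPEPLL_NVFP4_DISPATCH"] = "1"
assert _env_flag("VLLM_DEEPEPLL_NVFP4_DISPATCH") is True
del os.environ["VLLM_DEEPEPLL_NVFP4_DISPATCH"]
assert _env_flag("VLLM_DEEPEPLL_NVFP4_DISPATCH") is False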

vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
Lines changed: 15 additions & 10 deletions

@@ -114,7 +114,7 @@ def _do_quant(
     assert isinstance(x, (torch.Tensor, tuple))
     q_dtype = quant_config.quant_dtype
 
-    if q_dtype == "nvfp4":
+    if q_dtype == "nvfp4" and envs.VLLM_DEEPEPLL_NVFP4_DISPATCH:
         assert isinstance(x, tuple)
         # num_experts, max_tokens, hidden_dim_by_4 = x[0].size()
         # print(f"nvfp4 quantization input shape: {x[0].size()}, {x[0].dtype}, {x[0].is_contiguous()}")
@@ -125,13 +125,18 @@ def _do_quant(
         # print(f"nvfp4 quantization input shape: {x.size()}, {x.dtype}, {x.is_contiguous()}")
         num_experts, max_tokens, hidden_dim_by_2 = x.shape
         hidden_dim = hidden_dim_by_2 * 2
-        assert(envs.VLLM_FLASHINFER_MOE_BACKEND == "cutedsl")
-        logger.info_once("skip nvfp4 quant since done by deepep!!")
-        # logger.info_once(
-        #     "Skip quantization when using FlashInfer CUTEDSL for "
-        #     "ModelOptNvFp4FusedMoE."
-        # )
+        assert envs.VLLM_FLASHINFER_MOE_BACKEND == "cutedsl"
+        logger.info_once(
+            "Quantization is fused with DeepEP nvfp4 dispatch for "
+            "FlashInfer CUTEDSL since VLLM_DEEPEPLL_NVFP4_DISPATCH=1"
+        )
     else:
+        if q_dtype == "nvfp4":
+            q_dtype = None
+            logger.info_once(
+                "Using DeepEP bfloat16 dispatch for FlashInfer CUTEDSL since "
+                "VLLM_DEEPEPLL_NVFP4_DISPATCH=0"
+            )
         assert isinstance(x, torch.Tensor)
         num_experts, max_tokens, hidden_dim = x.size()
 
@@ -146,7 +151,6 @@ def _do_quant(
         )
         x = x.view((num_experts, -1, hidden_dim))
 
-
     if q_dtype is not None:
         assert x_scales is not None
         if q_dtype != "nvfp4":
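The net effect of these hunks: with the flag on, tokens arrive already NVFP4-quantized by DeepEP and _do_quant only reshapes them; with the flag off, an nvfp4 q_dtype is downgraded to None so tokens stay bfloat16 through dispatch and are quantized later. A standalone sketch of that decision (resolve_dispatch_dtype is an illustrative name, not vLLM API):

def resolve_dispatch_dtype(q_dtype: str | None, nvfp4_dispatch: bool) -> str | None:
    # Flag on: DeepEP already quantized during dispatch, keep "nvfp4"
    # and skip local quantization.
    if q_dtype == "nvfp4" and nvfp4_dispatch:
        return "nvfp4"
    # Flag off: fall back to bfloat16 dispatch; quantization happens later.
    if q_dtype == "nvfp4":
        return None
    # Other dtypes are unaffected by the flag.
    return q_dtype

assert resolve_dispatch_dtype("nvfp4", True) == "nvfp4"
assert resolve_dispatch_dtype("nvfp4", False) is None
assert resolve_dispatch_dtype("fp8", False) == "fp8"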
@@ -182,11 +186,12 @@ def prepare_async(
 
     use_nvfp4 = False
     # print("mm"*100, quant_config.quant_dtype)
-    if quant_config.quant_dtype == "nvfp4":
+    nvfp4_dispatch = quant_config.quant_dtype == "nvfp4" and envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
+    if nvfp4_dispatch:
         # print("gg"*100)
         # print(quant_config.a1_gscale)
         use_nvfp4 = True
-    qc_a1_gscale_or_scale = quant_config.a1_gscale if quant_config.quant_dtype == "nvfp4" else quant_config.a1_scale
+    qc_a1_gscale_or_scale = quant_config.a1_gscale if nvfp4_dispatch else quant_config.a1_scale
     has_per_token_scales = (
         qc_a1_gscale_or_scale.numel() != 1
         if qc_a1_gscale_or_scale is not None
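Scale selection in prepare_async now keys off the combined gate: the global scale a1_gscale is used only when NVFP4 dispatch is actually active, a1_scale otherwise, and per-token scales are inferred from the chosen tensor having more than one element. A sketch, assuming the condition truncated in the diff falls back to False when no scale is present:

import torch

def has_per_token_scales(scale: torch.Tensor | None) -> bool:
    # More than one element implies one scale per token rather than a
    # single tensor-wide scale; None is assumed to mean no per-token scales.
    return scale.numel() != 1 if scale is not None else False

assert has_per_token_scales(torch.ones(128)) is True
assert has_per_token_scales(torch.ones(1)) is False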

vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
Lines changed: 6 additions & 6 deletions

@@ -5,6 +5,7 @@
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
+from vllm import envs
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate,
@@ -18,8 +19,6 @@
 
 logger = init_logger(__name__)
 
-CUTEDSL_MOE_NVFP4_DISPATCH = True
-
 def is_valid_flashinfer_cutedsl_fused_moe(
     hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor
 ) -> bool:
@@ -110,7 +109,7 @@ def workspace_shapes(
     - Note: in order for activation chunking to work, the first dimension
       of each tuple must be the number of tokens.
     """
-    if CUTEDSL_MOE_NVFP4_DISPATCH:
+    if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH:
         # since it sees quantized a1q
         K_dim = K * 2
     else:
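The K * 2 widening follows from NVFP4 packing: with the flag on, workspace_shapes sees the already-quantized a1q, whose last dimension holds two 4-bit values per byte, so the logical hidden size is twice the packed size (the same hidden_dim_by_2 * 2 arithmetic as in _do_quant). A tiny sketch of the relationship, with an illustrative hidden size:

def logical_hidden_dim(packed_k: int, nvfp4_dispatch: bool) -> int:
    # Packed fp4 tensors store two 4-bit values per byte, so the packed
    # width is half the logical width; bf16 tensors are already logical.
    return packed_k * 2 if nvfp4_dispatch else packed_k

assert logical_hidden_dim(3584, True) == 7168
assert logical_hidden_dim(7168, False) == 7168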
@@ -151,10 +150,11 @@ def apply(
         assert self.w1_scale.ndim == 3
         assert self.w2_scale.ndim == 3
 
-        # TODO(shuw): replace True by CUTEDSL_MOE_NVFP4_DISPATCH
+        input_global_scale = (None if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else self.a1_gscale)
+        flashinfer_hidden_states = (hidden_states, a1q_scale) if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else hidden_states
         flashinfer_cutedsl_moe_masked(
-            hidden_states=(hidden_states, a1q_scale) if CUTEDSL_MOE_NVFP4_DISPATCH else hidden_states,
-            input_global_scale=(None if CUTEDSL_MOE_NVFP4_DISPATCH else self.a1_gscale),
+            hidden_states=flashinfer_hidden_states,
+            input_global_scale=input_global_scale,
             w1=w1,
             w1_blockscale=self.w1_scale,
             w1_alpha=self.g1_alphas,
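In apply(), the flag now chooses what flashinfer_cutedsl_moe_masked receives: with dispatch-side quantization, the packed fp4 hidden states travel with their blockscales as a tuple and no input global scale is needed; without it, bf16 hidden states plus a1_gscale are passed so the kernel quantizes internally. A standalone sketch of that selection (select_kernel_inputs is illustrative, not vLLM API):

def select_kernel_inputs(hidden_states, a1q_scale, a1_gscale, nvfp4_dispatch: bool):
    if nvfp4_dispatch:
        # Already quantized by DeepEP: pass (packed fp4, blockscales);
        # the kernel needs no global input scale.
        return (hidden_states, a1q_scale), None
    # bf16 path: the kernel quantizes internally using a1_gscale.
    return hidden_states, a1_gscale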
