Commit 7776793

lfr-0531 and litaotju authored
[https://nvbugs/5485325][fix] Add a postprocess to the model engine to fix the CUDA graph warmup issue when using speculative decoding (#7373)
Signed-off-by: Fanrong Li <[email protected]>
Co-authored-by: Tao Li @ NVIDIA <[email protected]>
1 parent 26fc7da commit 7776793

File tree

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
tests/integration/defs/accuracy/test_llm_api_pytorch.py
tests/integration/test_lists/qa/llm_function_full.txt
tests/integration/test_lists/qa/llm_function_sanity.txt
tests/integration/test_lists/test-db/l0_dgx_b200.yml

6 files changed: +87 -3 lines changed

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 9 additions & 2 deletions
@@ -136,8 +136,11 @@ def maybe_get_cuda_graph(self, batch: ScheduledRequests):
     def needs_capture(self, batch_size: int):
         return (batch_size, self.draft_len) not in self.graph_outputs
 
-    def capture(self, batch_size: int, forward_fn: Callable,
-                initial_inputs: Dict[str, Any]):
+    def capture(self,
+                batch_size: int,
+                forward_fn: Callable,
+                initial_inputs: Dict[str, Any],
+                postprocess_fn: Optional[Callable] = None):
         """Captures the forward pass for a given batch size."""
         engine = self._get_engine()
         key = (batch_size, self.draft_len)
@@ -181,8 +184,12 @@ def capture(self, batch_size: int, forward_fn: Callable,
         with with_multi_stream(True), piecewise_cuda_graph(False):
             for _ in range(self.WARMUP_STEPS):
                 forward_fn(capture_inputs)
+                if postprocess_fn is not None:
+                    postprocess_fn(capture_inputs)
             with torch.cuda.graph(graph, pool=self.memory_pool):
                 output = forward_fn(capture_inputs)
+                if postprocess_fn is not None:
+                    postprocess_fn(capture_inputs)
 
         self.graphs[key] = graph
         self.graph_outputs[key] = make_weak_ref(output)
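In short, capture now accepts an optional postprocess_fn that runs after every warmup forward and after the captured forward itself. Below is a minimal, self-contained sketch of the pattern (not the TensorRT-LLM implementation; toy_forward, undo_offsets, and the input dict are illustrative assumptions) showing why the hook matters when the forward mutates its inputs in place: without it, each warmup pass would see inputs already shifted by the previous one.

import torch

WARMUP_STEPS = 2

def toy_forward(inputs):
    # Stand-in for the model forward: like the spec-decode + overlap-scheduler
    # path, it shifts position_ids in place before producing an output.
    inputs["position_ids"] += inputs["offsets"]
    return inputs["position_ids"].float().sum()

def undo_offsets(inputs):
    # Stand-in for the new postprocess hook: revert the in-place change so the
    # next pass (warmup or capture) starts from identical inputs.
    inputs["position_ids"] -= inputs["offsets"]

def capture(inputs, postprocess_fn=None):
    graph = torch.cuda.CUDAGraph()
    for _ in range(WARMUP_STEPS):          # warmup passes outside the graph
        toy_forward(inputs)
        if postprocess_fn is not None:
            postprocess_fn(inputs)         # keep warmup passes idempotent
    with torch.cuda.graph(graph):          # the real runner also passes a memory pool
        out = toy_forward(inputs)
        if postprocess_fn is not None:
            postprocess_fn(inputs)         # each graph replay also restores the inputs
    return graph, out

if torch.cuda.is_available():
    inputs = {
        "position_ids": torch.arange(8, device="cuda"),
        "offsets": torch.ones(8, dtype=torch.long, device="cuda"),
    }
    graph, out = capture(inputs, postprocess_fn=undo_offsets)
    graph.replay()                         # inputs come back unchanged after replay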

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 25 additions & 1 deletion
@@ -1133,6 +1133,26 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
                         self.previous_kv_lens_offsets_cuda[:num_gen_requests])
         return inputs
 
+    def _postprocess_inputs(self, inputs: Dict[str, Any]):
+        """
+        Postprocess to make sure model forward doesn't change the inputs.
+        It is only used in cuda graph capture, because other cases will prepare
+        new inputs before the model forward.
+        """
+        if self.enable_spec_decode and not self._disable_overlap_scheduler:
+            if inputs['attn_metadata'].kv_cache_manager is not None:
+                num_seqs = inputs['attn_metadata'].num_seqs
+                num_ctx_requests = inputs['attn_metadata'].num_contexts
+                num_gen_requests = inputs['attn_metadata'].num_generations
+                num_ctx_tokens = inputs['attn_metadata'].num_ctx_tokens
+                previous_batch_tokens = inputs['input_ids'].shape[
+                    0] - num_ctx_tokens
+                inputs['position_ids'][0, num_ctx_tokens:] -= (
+                    self.previous_pos_id_offsets_cuda[:previous_batch_tokens])
+                inputs['attn_metadata'].kv_lens_cuda[
+                    num_ctx_requests:num_seqs] -= (
+                        self.previous_kv_lens_offsets_cuda[:num_gen_requests])
+
     def _get_all_rank_num_tokens(self, attn_metadata: AttentionMetadata):
         if self.enable_attention_dp:
             return list(self.dist.tp_allgather(attn_metadata.num_tokens))
@@ -2206,8 +2226,12 @@ def capture_forward_fn(inputs: Dict[str, Any]):
                     gather_ids=gather_ids,
                     gather_context_logits=gather_context_logits)
 
+            def capture_postprocess_fn(inputs: Dict[str, Any]):
+                self._postprocess_inputs(inputs)
+
             self.cuda_graph_runner.capture(batch_size,
-                                           capture_forward_fn, inputs)
+                                           capture_forward_fn, inputs,
+                                           capture_postprocess_fn)
 
             # here we don't need to use context since cuda graph capture didn't run kernel.
             # maybe we need a cleaner way to do this.
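Taken together, _postprocess_inputs mirrors _preprocess_inputs: the offsets that preprocessing adds in place to position_ids and kv_lens_cuda are subtracted back after the forward, so repeated warmup passes during CUDA graph capture do not accumulate them. A minimal sketch of that symmetry, using plain tensors as assumed stand-ins for the real attn_metadata and offset buffers:

import torch

# Toy offsets standing in for previous_pos_id_offsets_cuda and
# previous_kv_lens_offsets_cuda; the real buffers live on the model engine.
pos_id_offsets = torch.tensor([1, 1, 1])
kv_len_offsets = torch.tensor([2, 2, 2])

def preprocess(inputs):
    # What _preprocess_inputs does conceptually: shift the inputs in place for
    # the overlap-scheduler + speculative-decoding path.
    inputs["position_ids"] += pos_id_offsets
    inputs["kv_lens"] += kv_len_offsets

def postprocess(inputs):
    # What _postprocess_inputs undoes, so the next warmup/capture pass starts clean.
    inputs["position_ids"] -= pos_id_offsets
    inputs["kv_lens"] -= kv_len_offsets

inputs = {"position_ids": torch.tensor([10, 11, 12]),
          "kv_lens": torch.tensor([100, 101, 102])}
before = {k: v.clone() for k, v in inputs.items()}

preprocess(inputs)    # the model forward would run here during warmup
postprocess(inputs)   # restore the shifted values

assert all(torch.equal(inputs[k], before[k]) for k in inputs)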

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 36 additions & 0 deletions
@@ -1871,6 +1871,42 @@ def test_nvfp4_multi_gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
         # task.evaluate(llm,
         #               extra_evaluator_kwargs=dict(apply_chat_template=True))
 
+    def test_nvfp4_multi_gpus_corner_case(self):
+        """
+        Test a corner case of the NVFP4 model.
+        When max_seq_len and max_num_tokens are set to the same value, there are
+        not enough KV blocks for the dummy requests in CUDA graph warmup when
+        creating the py_executor before estimating the KV cache. CUDA graph
+        capture is then triggered during KV cache estimation, which may cause
+        errors. More info in https://nvbugs/5485325.
+        """
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.80,
+                                        dtype="fp8",
+                                        enable_block_reuse=False)
+        pytorch_config = dict(disable_overlap_scheduler=False,
+                              cuda_graph_config=CudaGraphConfig(
+                                  enable_padding=True, max_batch_size=1024),
+                              moe_config=MoeConfig(backend="TRTLLM"))
+
+        mtp_config = MTPDecodingConfig(num_nextn_predict_layers=1)
+        with LLM(f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1-FP4",
+                 tensor_parallel_size=8,
+                 pipeline_parallel_size=1,
+                 moe_expert_parallel_size=8,
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 enable_attention_dp=False,
+                 speculative_config=mtp_config,
+                 max_seq_len=5120,
+                 max_num_tokens=5120) as llm:
+
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_mpi_world_size(8)
     @skip_pre_hopper
     @pytest.mark.parametrize(

tests/integration/test_lists/qa/llm_function_full.txt

Lines changed: 1 addition & 0 deletions
@@ -526,6 +526,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp4]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput]
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency-torch_compile=False]

tests/integration/test_lists/qa/llm_function_sanity.txt

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp4]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp8]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False]

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 15 additions & 0 deletions
@@ -70,6 +70,21 @@ l0_dgx_b200:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] TIMEOUT (180)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp] TIMEOUT (180)
+- condition:
+    ranges:
+      system_gpu_count:
+        gte: 8
+        lte: 8
+    wildcards:
+      gpu:
+      - '*b200*'
+      linux_distribution_name: ubuntu*
+      cpu: x86_64
+    terms:
+      stage: post_merge
+      backend: pytorch
+  tests:
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (180)
 - condition:
     ranges:
       system_gpu_count:

0 commit comments
