
Commit 26fc7da

yunruis and litaotju authored
[None][opt] Add batch waiting when scheduling (#7287)
Signed-off-by: yunruis <[email protected]>
Co-authored-by: Tao Li @ NVIDIA <[email protected]>

1 parent 49b457c commit 26fc7da

File tree

7 files changed: +126 −1 lines

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 3 additions & 0 deletions
@@ -137,6 +137,9 @@ def __init__(
         self.pytorch_backend_config.attention_dp_time_out_iters = 50
         self.pytorch_backend_config.attention_dp_batching_wait_iters = 10
         self.pytorch_backend_config.batch_wait_timeout_ms = 0
+        self.pytorch_backend_config.batch_wait_timeout_iters = 0
+        self.pytorch_backend_config.batch_wait_max_tokens_ratio = 0.0
+        self.pytorch_backend_config.max_num_tokens = 8192
         self.iter_counter = 0

         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...

tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 4 additions & 0 deletions
@@ -50,7 +50,11 @@ class PyTorchConfig:
     attention_dp_time_out_iters: int = 50
     attention_dp_batching_wait_iters: int = 10

+    max_num_tokens: int = 8192
+
     batch_wait_timeout_ms: float = 0
+    batch_wait_timeout_iters: int = 0
+    batch_wait_max_tokens_ratio: float = 0

     attn_backend: str = 'TRTLLM'
     moe_backend: str = 'CUTLASS'
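
Taken together, these PyTorchConfig fields give the executor its batch-waiting budget: max_num_tokens caps the token count per batch, batch_wait_timeout_iters bounds how many scheduler iterations to wait, and batch_wait_max_tokens_ratio sets the fill level at which waiting stops. A minimal sketch of constructing the config directly (illustration only, assuming PyTorchConfig is the dataclass shown above; in practice these values are wired through TorchLlmArgs.get_pytorch_backend_config):

    # Illustrative only: build a PyTorchConfig with the new batch-waiting knobs.
    # Assumes the dataclass fields added in this commit; the defaults keep waiting disabled.
    from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

    config = PyTorchConfig(
        max_num_tokens=8192,               # token budget per scheduled batch
        batch_wait_timeout_iters=10,       # wait at most 10 scheduler iterations
        batch_wait_max_tokens_ratio=0.75,  # ...or until 75% of max_num_tokens is scheduled
    )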

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 35 additions & 0 deletions
@@ -184,6 +184,7 @@ def __init__(
         self.active = True
         self.max_beam_width = max_beam_width
         self.max_draft_len = max_draft_len
+        self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens
         self.print_log = model_engine.pytorch_backend_config.print_iter_log
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
         self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
@@ -192,6 +193,10 @@ def __init__(
         self.attention_dp_time_out_iters = model_engine.pytorch_backend_config.attention_dp_time_out_iters
         self.attention_dp_batching_wait_iters = model_engine.pytorch_backend_config.attention_dp_batching_wait_iters
         self.batch_wait_timeout_ms = model_engine.pytorch_backend_config.batch_wait_timeout_ms
+        self.batch_wait_timeout_iters = model_engine.pytorch_backend_config.batch_wait_timeout_iters
+        self.batch_wait_max_tokens_ratio = model_engine.pytorch_backend_config.batch_wait_max_tokens_ratio
+        self.enable_batch_waiting = self.batch_wait_timeout_iters > 0 or self.batch_wait_max_tokens_ratio > 0
+
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
@@ -236,6 +241,7 @@ def __init__(
         self.max_batch_size = max_batch_size
         self.adp_ctx_waiting_iters_count = 0
         self.adp_ctx_batching_wait_iters_count = 0
+        self.batch_wait_iters_count = 0

         # request fetcher initialization
         self.executor_request_queue = ExecutorRequestQueue(
@@ -1334,6 +1340,27 @@ def _balance_adp_requests(self, context_requests: list[LlmRequest],
             balanced_context_requests = context_requests
         return balanced_context_requests

+    def _waiting_requests(self, context_requests: list[LlmRequest],
+                          generation_requests: list[LlmRequest]):
+        if not self.enable_batch_waiting:
+            return context_requests
+
+        waited_context_requests = []
+        stop_waiting = False
+        num_scheduled_ctx_tokens = sum(
+            len(ctx_req.get_tokens(0)) for ctx_req in context_requests)
+        num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens
+                                       for gen_req in generation_requests)
+        num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens
+
+        stop_waiting = self.batch_wait_iters_count >= self.batch_wait_timeout_iters or num_scheduled_tokens >= self.batch_wait_max_tokens_ratio * self.max_num_tokens
+        if stop_waiting:
+            waited_context_requests = context_requests
+            self.batch_wait_iters_count = 0
+        else:
+            self.batch_wait_iters_count += 1
+        return waited_context_requests
+
     @nvtx_range("_schedule")
     def _schedule(self):
         scheduler_output = self.scheduler.schedule_request(
@@ -1344,6 +1371,14 @@ def _schedule(self):
                 scheduler_output.context_requests,
                 scheduler_output.generation_requests)

+        # if no generation requests, no need to wait, to avoid dead waiting
+        if not self.enable_attention_dp and self.enable_batch_waiting and len(
+                scheduler_output.context_requests) > 0 and len(
+                scheduler_output.generation_requests) > 0:
+            scheduled_context_requests = self._waiting_requests(
+                scheduler_output.context_requests,
+                scheduler_output.generation_requests)
+
         scheduled_requests = ScheduledRequests()
         scheduled_requests.context_requests = scheduled_context_requests
         scheduled_requests.generation_requests = scheduler_output.generation_requests
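
In short, _schedule only defers context requests when generation requests are also scheduled (and attention DP is off), and _waiting_requests releases the deferred context batch once either limit is hit: the iteration budget is exhausted or enough tokens are already scheduled. A self-contained restatement of that stop condition (hypothetical helper, not part of the commit):

    # Hypothetical helper mirroring the stop-waiting condition added above.
    def should_stop_waiting(wait_iters: int, num_scheduled_tokens: int,
                            batch_wait_timeout_iters: int,
                            batch_wait_max_tokens_ratio: float,
                            max_num_tokens: int) -> bool:
        # Stop holding back context requests once either threshold is reached.
        return (wait_iters >= batch_wait_timeout_iters
                or num_scheduled_tokens >= batch_wait_max_tokens_ratio * max_num_tokens)

    # With ratio 0.75 and max_num_tokens 8192, waiting ends once 6144 tokens are
    # scheduled, or after batch_wait_timeout_iters iterations, whichever comes first.
    assert should_stop_waiting(0, 6144, 10, 0.75, 8192)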

tensorrt_llm/llmapi/llm_args.py

Lines changed: 32 additions & 1 deletion
@@ -2255,6 +2255,18 @@ class TorchLlmArgs(BaseLlmArgs):
         "If greater than 0, the request queue might wait up to batch_wait_timeout_ms to receive max_batch_size requests, if fewer than max_batch_size requests are currently available. If 0, no waiting occurs.",
         status="prototype")

+    batch_wait_timeout_iters: int = Field(
+        default=0,
+        description=
+        "Maximum number of iterations the scheduler will wait to accumulate new coming requests for improved GPU utilization efficiency. If greater than 0, the scheduler will delay batch processing to gather more requests up to the specified iteration limit. If 0, disables timeout-iters-based batching delays.",
+        status="prototype")
+
+    batch_wait_max_tokens_ratio: float = Field(
+        default=0,
+        description=
+        "Token accumulation threshold ratio for batch scheduling optimization. If greater than 0, the scheduler will accumulate requests locally until the total token count reaches batch_wait_max_tokens_ratio * max_num_tokens. This mechanism enhances GPU utilization efficiency by ensuring adequate batch sizes.If 0 disables token-based batching delays.",
+        status="prototype")
+
     torch_compile_config: Optional[TorchCompileConfig] = Field(
         default=None, description="Torch compile config.", status="prototype")

@@ -2528,6 +2540,22 @@ def validate_batch_wait_timeout_ms(self) -> 'TorchLlmArgs':
             raise ValueError("batch_wait_timeout_ms must be greater than 0")
         return self

+    @model_validator(mode='after')
+    def validate_batch_wait_timeout_iters(self) -> 'TorchLlmArgs':
+        if self.batch_wait_timeout_iters < 0:
+            raise ValueError(
+                f"batch_wait_timeout_iters must be >= 0, got {self.batch_wait_timeout_iters}"
+            )
+        return self
+
+    @model_validator(mode='after')
+    def validate_batch_wait_max_tokens_ratio(self) -> 'TorchLlmArgs':
+        if self.batch_wait_max_tokens_ratio < 0 or self.batch_wait_max_tokens_ratio > 1:
+            raise ValueError(
+                f"batch_wait_max_tokens_ratio must be in range [0, 1], got {self.batch_wait_max_tokens_ratio}"
+            )
+        return self
+
     def get_executor_config(
         self,
         _hf_model_dir: Optional[Path] = None,
@@ -2603,7 +2631,10 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             attention_dp_batching_wait_iters=self.attention_dp_config.
             batching_wait_iters if self.attention_dp_config is not None else
             AttentionDpConfig.model_fields['batching_wait_iters'].default,
-            batch_wait_timeout_ms=self.batch_wait_timeout_ms)
+            batch_wait_timeout_ms=self.batch_wait_timeout_ms,
+            batch_wait_timeout_iters=self.batch_wait_timeout_iters,
+            batch_wait_max_tokens_ratio=self.batch_wait_max_tokens_ratio,
+        )


 def update_llm_args_with_extra_dict(
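
For reference, the two new TorchLlmArgs fields can be passed straight to the LLM API, as the test added below also does; a hedged usage sketch (the model path is a placeholder):

    from tensorrt_llm import LLM

    llm = LLM(
        model="/path/to/model",            # placeholder
        batch_wait_timeout_iters=10,       # wait up to 10 iterations to accumulate requests
        batch_wait_max_tokens_ratio=0.75,  # or until 75% of max_num_tokens is scheduled
    )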

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 43 additions & 0 deletions
@@ -1546,6 +1546,49 @@ def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

+    @skip_pre_blackwell
+    @parametrize_with_ids("torch_compile", [False, True])
+    @parametrize_with_ids("fp8kv,cuda_graph,overlap_scheduler",
+                          [(False, False, False), (True, True, True)])
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    @parametrize_with_ids(
+        "batch_wait_timeout_iters,batch_wait_max_tokens_ratio", [(0, 0),
+                                                                 (10, 0.75),
+                                                                 (10, 0),
+                                                                 (0, 0.75)])
+    def test_nvfp4_batch_waiting(self, torch_compile, fp8kv, cuda_graph,
+                                 overlap_scheduler, mtp_nextn,
+                                 batch_wait_timeout_iters,
+                                 batch_wait_max_tokens_ratio):
+        moe_backend = "CUTLASS"
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
+        torch_compile_config = TorchCompileConfig(
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=cuda_graph,
+            capture_num_tokens=[2048, 8192],
+            max_num_streams=3) if torch_compile else None
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
+            torch_compile_config=torch_compile_config,
+            batch_wait_timeout_iters=batch_wait_timeout_iters,
+            batch_wait_max_tokens_ratio=batch_wait_max_tokens_ratio,
+            moe_config=MoeConfig(backend=moe_backend))
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+        if fp8kv:
+            kv_cache_config.dtype = "fp8"
+        with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp",
+                 kv_cache_config=kv_cache_config,
+                 **pytorch_config,
+                 enable_attention_dp=False,
+                 speculative_config=mtp_config) as llm:
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
+
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_device(4)
     @skip_pre_blackwell
     @parametrize_with_ids("torch_compile", [False, True])

tests/integration/test_lists/qa/llm_function_full.txt

Lines changed: 1 addition & 0 deletions
@@ -512,6 +512,7 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_batch_waiting[batch_wait_timeout_iters=10-batch_wait_max_tokens_ratio=0.75-mtp_nextn=0-fp8kv=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus_static_eplb
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]

tests/unittest/api_stability/references/llm.yaml

Lines changed: 8 additions & 0 deletions
@@ -131,6 +131,14 @@ methods:
       annotation: float
       default: 0
       status: prototype
+    batch_wait_timeout_iters:
+      annotation: int
+      default: 0
+      status: prototype
+    batch_wait_max_tokens_ratio:
+      annotation: float
+      default: 0
+      status: prototype
     print_iter_log:
       annotation: bool
       default: False
