
Commit a6b8e25

joerunde authored and rafvasq committed
🐛 fix batch handling in V1 runner (#33)
* 🐛 fix batch handling in V1 runner
* ⚗️ try v1 test only
* ⚗️ add a bit more prompt
* ⚗️ unclear why CI won't count to 0
* ♻️ rename map_output_indices

---------

Signed-off-by: Joe Runde <[email protected]>
1 parent 4a48c7b commit a6b8e25

File tree: 4 files changed, +88 −3 lines changed

tests/test_spyre_basic.py

Lines changed: 49 additions & 1 deletion
@@ -72,4 +72,52 @@ def test_output(
                     tensor_parallel_size=1,
                     backend=backend,
                     vllm_results=vllm_results,
-                    hf_results=hf_results)
+                    hf_results=hf_results)
+
+
+@pytest.mark.parametrize("model", get_spyre_model_list())
+@pytest.mark.parametrize("backend", get_spyre_backend_list())
+@pytest.mark.parametrize("vllm_version", ["V0", "V1"])
+def test_batch_handling(
+    model: str,
+    backend: str,
+    vllm_version: str,
+):
+    """Test that the spyre worker correctly handles batches of requests that
+    finish after different numbers of forward passes"""
+
+    # Test with batch size 4
+    warmup_shape = (64, 20, 4)
+
+    # Have the model count down to one and stop
+    vllm_sampling_params = SamplingParams(max_tokens=20,
+                                          temperature=0,
+                                          stop="1",
+                                          logprobs=0)
+    # Importantly, these prompts are ordered so that they don't finish in the
+    # order given
+    prompts = [
+        "7 6 5 4",
+        "10 9 8 7",
+        "8 7 6 5",
+        "9 8 7 6",
+    ]
+
+    # Ensure that both:
+    # - The model doesn't crash
+    # - The output sequences are correct
+    vllm_results = generate_spyre_vllm_output(
+        model=model,
+        prompts=prompts,
+        warmup_shapes=[warmup_shape],
+        max_model_len=2048,
+        block_size=2048,
+        sampling_params=vllm_sampling_params,
+        tensor_parallel_size=1,
+        backend=backend,
+        vllm_version=vllm_version)
+
+    assert vllm_results[0]["text"] == " 3 2 "
+    assert vllm_results[1]["text"] == " 6 5 4 3 2 "
+    assert vllm_results[2]["text"] == " 4 3 2 "
+    assert vllm_results[3]["text"] == " 5 4 3 2 "
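
Note that the expected strings follow mechanically from the prompts: the model continues the countdown, generation halts at the stop string "1", and vLLM strips the stop string from the returned text, leaving a trailing space. A small hypothetical helper (not part of the PR) that derives the expected outputs under those assumptions:

def expected_countdown(prompt: str) -> str:
    # Continue counting down from the last number in the prompt, stopping
    # before "1"; the stop string is stripped, leaving a trailing space.
    start = int(prompt.split()[-1]) - 1
    return "".join(f" {n}" for n in range(start, 1, -1)) + " "

assert expected_countdown("7 6 5 4") == " 3 2 "
assert expected_countdown("10 9 8 7") == " 6 5 4 3 2 "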

vllm_spyre/platform.py

Lines changed: 9 additions & 1 deletion
@@ -102,7 +102,15 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # the scheduler always thinks there's a block available
         model_config.max_model_len = max_seq_len
         cache_config.block_size = model_config.max_model_len
-        cache_config.num_gpu_blocks_override = scheduler_config.max_num_seqs
+
+        if envs.VLLM_USE_V1:
+            # The V1 scheduler actually needs 2 blocks for each sequence...
+            cache_config.num_gpu_blocks_override = \
+                scheduler_config.max_num_seqs * 2
+        else:
+            cache_config.num_gpu_blocks_override = \
+                scheduler_config.max_num_seqs
+
         logger.info(
             "Overriding configurations based on warmup shapes. "
             "max_model_len=%d, max_num_seqs=%d, block_size=%d, "

vllm_spyre/v1/core/scheduler.py

Lines changed: 7 additions & 0 deletions
@@ -106,6 +106,13 @@ def schedule(self) -> "SchedulerOutput":
                 # can work with the batch we have
                 break

+            logger.debug(
+                "Scheduling a new batch of %d requests, holding back %d "
+                "requests", len(self.waiting), len(self.holdback_queue))
+        else:
+            logger.debug("Scheduling a running batch of %d requests",
+                         len(self.running))
+
         outputs = super().schedule()
         return outputs


vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 23 additions & 1 deletion
@@ -366,7 +366,7 @@ def execute_model(

         model_output = ModelRunnerOutput(
             req_ids=list(self._req_ids2idx.keys()),
-            req_id_to_index=self._req_ids2idx,
+            req_id_to_index=self._get_unpadded_output_indices(),
             sampled_token_ids=output.sampled_token_ids.tolist(),
             spec_token_ids=None,
             logprobs=output.logprobs_tensors.tolists()
@@ -378,6 +378,28 @@ def execute_model(
         )
         return model_output

+    def _get_unpadded_output_indices(self) -> dict[str, int]:
+        """The inputs to the model are all padded to a constant batch size, and
+        self.req_id2idx is the map of request id -> padded index.
+        However, finished requests and padded requests are stripped from the
+        output, so the mapping of request id -> unpadded output index needs to
+        be created to be returned in `ModelRunnerOutput`.
+
+        For example if:
+        - self.model.indices = [F, T, T, F]
+        - self.req_ids2ix = {"A": 0, "B": 1, "C": 2, "D": 3}
+        This will output: {"B": 0, "C": 1}
+        """
+        remapped_indices = {}
+        for req_id, idx in self._req_ids2idx.items():
+            if self.model.indices[idx]:
+                # Sum up all the requests to the left of this one that are still
+                # processing. That should be this requests' index in the output
+                # tensor.
+                remapped_indices[req_id] = self.model.indices[0:idx].sum(
+                ).item()
+        return remapped_indices
+
     def _prepare_pad_input_ids(
         self,
         input_ids_list: List[torch.Tensor],
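
The docstring's example can be reproduced with a standalone sketch of the same idea (names here are illustrative, not the runner's actual attributes): an active request's position in the stripped output is simply the count of still-active slots to its left in the padded batch.

import torch

def unpadded_output_indices(req_ids2idx: dict[str, int],
                            active: torch.Tensor) -> dict[str, int]:
    # `active` marks which padded batch slots are still generating; counting
    # the active slots to the left of each active request gives its index in
    # the stripped output.
    return {
        req_id: int(active[:idx].sum().item())
        for req_id, idx in req_ids2idx.items() if active[idx]
    }

active = torch.tensor([False, True, True, False])
assert unpadded_output_indices({"A": 0, "B": 1, "C": 2, "D": 3},
                               active) == {"B": 0, "C": 1}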
