
Commit 0ceea20

🐛 fixed static batch warmup (#246)
# Description

This changes the warmup for static batching back to how it originally was, only warming up a single pass. This fixes a bug where the compiled model graphs were incorrect: we would invoke the model with one batch size and the model would output a different batch size.

Signed-off-by: Joe Runde <[email protected]>
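Structurally, the change moves the warmup context from around the whole fixed-size warmup (which runs two forward passes) to around just the first pass. Below is a minimal, self-contained sketch of that before/after shape; the helpers are placeholder stand-ins, not the real vllm_spyre code (only the names `_maybe_warmup_context` and the two-pass structure come from the diff below).

```python
from contextlib import contextmanager

# Illustrative stand-ins only; the real _maybe_warmup_context and warmup
# helpers live in vllm_spyre/v1/worker/spyre_worker.py and do more than this.

@contextmanager
def _maybe_warmup_context():
    # Placeholder for whatever backend warmup mode the real helper enables.
    yield

def _forward_pass(label: str) -> None:
    print(f"warmup forward pass {label}")

def warmup_before_fix() -> None:
    # Before this commit: both warmup forward passes ran inside the
    # warmup context.
    with _maybe_warmup_context():
        _forward_pass("1/2")
        _forward_pass("2/2")

def warmup_after_fix() -> None:
    # After this commit: only the first forward pass runs inside the
    # warmup context; the second pass executes normally.
    with _maybe_warmup_context():
        _forward_pass("1/2")
    _forward_pass("2/2")
```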
1 parent 9511982 commit 0ceea20

File tree

1 file changed, +6 -6 lines changed


vllm_spyre/v1/worker/spyre_worker.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -90,10 +90,8 @@ def compile_or_warm_up_model(self) -> None:
             logger.info(
                 "Warming up for prompt length %d, decoding %d tokens with "
                 "batch size %d", prompt_len, num_decode_tokens, batch_size)
-            with _maybe_warmup_context():
-                self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens,
-                                              self.restricted_tokens,
-                                              batch_size)
+            self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens,
+                                          self.restricted_tokens, batch_size)
         all_warmup_end_t = time.time()
         all_warmup_total_t = all_warmup_end_t - all_warmup_start_t
         self.perf_metrics.log("total warmup time", all_warmup_total_t)
@@ -537,8 +535,10 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
 
         # First full forward pass
         logger.info("Warmup forward pass 1/2...")
-        self._warmup_model_forward_pass(scheduler_output, dummy_requests,
-                                        cached_requests, num_decode_tokens)
+        # The fixed size warmup needs to happen only in here
+        with _maybe_warmup_context():
+            self._warmup_model_forward_pass(scheduler_output, dummy_requests,
+                                            cached_requests, num_decode_tokens)
         self.perf_metrics.log("warmup 1 time",
                               time.time() - warmup_start_t,
                               batch_size=batch_size,
```
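The fix hinges on `_maybe_warmup_context()`, whose implementation is not part of this diff. Purely as an assumption, helpers with this naming convention are often conditional context managers that enter a backend-specific warmup mode only when it applies and otherwise act as a no-op; the sketch below illustrates that generic pattern, not the actual vllm_spyre implementation.

```python
from contextlib import contextmanager, nullcontext

# Hypothetical sketch of the "maybe" context-manager pattern; the actual
# _maybe_warmup_context() in vllm_spyre is not shown in this diff and may
# differ.
def maybe_warmup_context(enable_warmup_mode: bool):
    if enable_warmup_mode:
        return _warmup_mode()
    # No-op context manager when warmup mode does not apply.
    return nullcontext()

@contextmanager
def _warmup_mode():
    print("entering warmup mode")
    try:
        yield
    finally:
        print("exiting warmup mode")

# Usage: only the code under the `with` statement runs in warmup mode.
with maybe_warmup_context(enable_warmup_mode=True):
    print("first warmup forward pass")
```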
