 import vllm_spyre.perf_metrics as perf_metrics
 from vllm_spyre.model_executor.model_loader import spyre_setup
 from vllm_spyre.platform import SpyrePlatform
-from vllm_spyre.v1.worker.spyre_input_batch import SamplingInputBatch
 from vllm_spyre.v1.worker.spyre_model_runner import (
     ContinuousBatchingSpyreModelRunner, SpyrePoolingModelRunner,
     StaticBatchingSpyreModelRunner, SupportedTask)
@@ -110,6 +109,9 @@ def compile_or_warm_up_model(self) -> None:
                 prompt_len, num_decode_tokens, batch_size)
             self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens,
                                           self.restricted_tokens, batch_size)
+
+        self.model_runner.complete_warmup()
+
         all_warmup_end_t = time.time()
         all_warmup_total_t = all_warmup_end_t - all_warmup_start_t
         self.perf_metrics.log("total warmup time", all_warmup_total_t)
@@ -119,7 +121,6 @@ def compile_or_warm_up_model(self) -> None:
             "[WARMUP] All %d prompt/decode/batchsize-shape "
             "combinations finished in %.3fs", num_shape_combinations,
             all_warmup_total_t)
-        self.model_runner.complete_warmup()

     def check_health(self) -> None:
         """Basic health check (override for device-specific checks)."""
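Read together, the two compile_or_warm_up_model hunks above relocate the warmup-completion hook: self.model_runner.complete_warmup() now runs once, immediately after the per-shape warmup loop, instead of after the summary log line at the end of the method. A rough reconstruction of the resulting order follows; the loop header is a placeholder name, and indentation and elided lines are approximated from the hunks rather than taken from the repository.

for prompt_len, num_decode_tokens, batch_size in warmup_shapes:  # placeholder iterable name
    self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens,
                                  self.restricted_tokens, batch_size)

self.model_runner.complete_warmup()  # moved here, right after the warmup loop

all_warmup_end_t = time.time()
all_warmup_total_t = all_warmup_end_t - all_warmup_start_t
self.perf_metrics.log("total warmup time", all_warmup_total_t)
# ...summary logging unchanged; the old complete_warmup() call after it is removed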
@@ -339,18 +340,6 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
         prompt_len = 42
         num_decode_tokens = 2

-        # Fix for batch size 1: set input batch to fit 2 requests for warmup
-        if model_runner.vllm_config.scheduler_config.max_num_seqs == 1:
-            model_runner.input_batch = SamplingInputBatch(
-                max_num_reqs=2,
-                max_model_len=model_runner.vllm_config.model_config.
-                max_model_len,
-                device=model_runner.device,
-                pin_memory=model_runner.pin_memory,
-                vocab_size=model_runner.vllm_config.model_config.
-                get_vocab_size(),
-            )
-
         # Sample from the valid token ids
         warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
             0, len(valid_token_ids_tensor), (batch_size + 1, prompt_len))]
@@ -398,20 +387,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
         self.execute_model(scheduler_output)
         self._cleanup_model_runner(request=[add_dummy_request])

-        # Fix for batch size 1: reset input batch to fit max_num_seqs requests
-        if model_runner.vllm_config.scheduler_config.max_num_seqs == 1:
-            model_runner.input_batch = SamplingInputBatch(
-                max_num_reqs=model_runner.vllm_config.scheduler_config.
-                max_num_seqs,
-                max_model_len=model_runner.vllm_config.model_config.
-                max_model_len,
-                device=model_runner.device,
-                pin_memory=model_runner.pin_memory,
-                vocab_size=model_runner.vllm_config.model_config.
-                get_vocab_size(),
-            )
-
-        model_runner.finish_warmup()
+        model_runner.complete_warmup()

         warmup_end_t = time.time()
         warmup_total_t = warmup_end_t - warmup_start_t
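The blocks removed in the last two hunks were a workaround for max_num_seqs == 1: the worker temporarily rebuilt the runner's SamplingInputBatch to hold two requests for the dynamic-size warmup and then rebuilt it again at its configured size afterwards. The diff drops both halves in favor of the runner-side complete_warmup() hook (and deletes the now-unused SamplingInputBatch import). The hook's body is not part of this diff; what follows is only a hypothetical sketch of what such a hook could do, reusing the constructor arguments from the deleted code. The stub class and its attributes are assumptions, not vllm-spyre's real implementation.

from vllm_spyre.v1.worker.spyre_input_batch import SamplingInputBatch


class ModelRunnerStub:  # hypothetical stand-in for the Spyre model runner
    def complete_warmup(self) -> None:
        # Assumption: once warmup is done, rebuild the input batch at its
        # configured serving size so a warmup-only over-allocation (e.g. two
        # requests while max_num_seqs == 1) does not persist into normal
        # execution.
        scheduler_config = self.vllm_config.scheduler_config
        model_config = self.vllm_config.model_config
        self.input_batch = SamplingInputBatch(
            max_num_reqs=scheduler_config.max_num_seqs,
            max_model_len=model_config.max_model_len,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
        )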