 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
-from vllm.v1.core.scheduler import CachedRequestData, NewRequestData, SchedulerOutput
+from vllm.v1.core.scheduler import (CachedRequestData, NewRequestData,
+                                    SchedulerOutput)
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.worker_base import WorkerBase as WorkerBaseV1
@@ -81,9 +82,8 @@ def compile_or_warm_up_model(self) -> None:
             "combinations finished. Total warmup time %.3fs.",
             len(wup_new_tokens), all_warmup_total_t)
 
-
     def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
-                                special_token_ids, batch_size):
+                                 special_token_ids, batch_size):
 
         warmup_start_t = time.time()
         # NOTE(ngl): empty tensor causes spyre to hang, so using
@@ -97,7 +97,9 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
             i for i in range(1, vocab_size) if i not in set(special_token_ids)
         ]
         # Convert to tensor for sampling
-        valid_token_ids_tensor = torch.tensor(valid_token_ids, dtype=torch.long, device="cpu")
+        valid_token_ids_tensor = torch.tensor(valid_token_ids,
+                                              dtype=torch.long,
+                                              device="cpu")
 
         # Sample from the valid token ids
         warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
@@ -106,10 +108,12 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         # Create requests to be used for prefill steps
         dummy_requests = [
             NewRequestData(
-                req_id=f"warmup",
+                req_id="warmup",
                 prompt_token_ids=warmup_tokens_tensor[i].tolist(),
                 prompt="test",
-                mm_inputs=[], mm_hashes=[], mm_positions=[],
+                mm_inputs=[],
+                mm_hashes=[],
+                mm_positions=[],
                 sampling_params=SamplingParams(max_tokens=num_decode_tokens),
                 block_ids=[0],
                 num_computed_tokens=0,
@@ -122,8 +126,10 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
             CachedRequestData(
                 req_id=req.req_id,
                 resumed_from_preemption=False,
-                new_token_ids=[valid_token_ids_tensor[torch.randint(
-                    0, len(valid_token_ids_tensor), (1,)).item()]],  # placeholder token
+                new_token_ids=[
+                    valid_token_ids_tensor[torch.randint(
+                        0, len(valid_token_ids_tensor), (1, )).item()]
+                ],  # placeholder token
                 new_block_ids=req.block_ids,
                 num_computed_tokens=req.num_computed_tokens,
             ) for req in dummy_requests
@@ -134,8 +140,10 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=dummy_requests,
             scheduled_cached_reqs=[],
-            num_scheduled_tokens={i: prompt_len for i in range(batch_size)},
-            total_num_scheduled_tokens=sum(prompt_len for _ in range(batch_size)),
+            num_scheduled_tokens={i: prompt_len
+                                  for i in range(batch_size)},
+            total_num_scheduled_tokens=sum(prompt_len
+                                           for _ in range(batch_size)),
             scheduled_spec_decode_tokens={},
             scheduled_encoder_inputs={},
             num_common_prefix_blocks=0,
@@ -144,14 +152,14 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         )
 
         # First full forward pass
-        logger.info("[SpyreWorker] Warmup 1/2: Prefill...")
+        logger.info("Warmup 1/2: Prefill...")
         self.execute_model(scheduler_output)  # Prefill step
 
         # Switch to cached requests to trigger decoding steps
         scheduler_output.scheduled_new_reqs = []
         scheduler_output.scheduled_cached_reqs = cached_requests
 
-        logger.info("[SpyreWorker] Warmup 1/2: Decoding...")
+        logger.info("Warmup 1/2: Decoding...")
         for _ in range(num_decode_tokens - 1):
             self.execute_model(scheduler_output)
 
@@ -161,10 +169,11 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
             ul_start_time = time.time()
             torch_sendnn.update_lazyhandle()
             ul_stop_time = time.time()
-            logger.info(f"update_lazyhandle() done (duration: {ul_stop_time - ul_start_time}s)")
+            logger.info("update_lazyhandle() done (duration: %.3fs)",
+                        ul_stop_time - ul_start_time)
 
         # Second full forward pass
-        logger.info("[SpyreWorker] Warmup 2/2: Prefill step...")
+        logger.info("Warmup 2/2: Prefill step...")
         scheduler_output.scheduled_new_reqs = dummy_requests
         scheduler_output.scheduled_cached_reqs = []
         self.execute_model(scheduler_output)
@@ -173,16 +182,16 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         scheduler_output.scheduled_new_reqs = []
         scheduler_output.scheduled_cached_reqs = cached_requests
 
-        logger.info("[SpyreWorker] Warmup 2/2: Decoding steps...")
+        logger.info("Warmup 2/2: Decoding steps...")
         for _ in range(num_decode_tokens - 1):
             self.execute_model(scheduler_output)
 
         warmup_end_t = time.time()
         warmup_total_t = warmup_end_t - warmup_start_t
-        logger.info("[SpyreWorker] ... warmup finished.")
-        logger.info(f"\twarmup took {warmup_total_t}s (for prompt length"
-                    f" {prompt_len} and max output tokens {num_decode_tokens})")
-
+        logger.info("Warmup finished.")
+        logger.info(
+            "Warmup took %.3fs (for prompt length %d and max output tokens %d)",
+            warmup_total_t, prompt_len, num_decode_tokens)
 
     def check_health(self) -> None:
         """Basic health check (override for device-specific checks)."""