@@ -82,117 +82,6 @@ def compile_or_warm_up_model(self) -> None:
8282 "combinations finished. Total warmup time %.3fs." ,
8383 len (wup_new_tokens ), all_warmup_total_t )
8484
-    def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
-                                 special_token_ids, batch_size):
-
-        warmup_start_t = time.time()
-        # NOTE(ngl): empty tensor causes spyre to hang, so using
-        # randint without 0 and the eos and bos token
-
-        # Create a list of valid values between 1 (inclusive) and vocab
-        # size (exclusive) by excluding the eos and bos token ids
-        # (in special_token_ids)
-        vocab_size = self.model_runner.vocab_size
-        valid_token_ids = [
-            i for i in range(1, vocab_size) if i not in set(special_token_ids)
-        ]
-        # Convert to tensor for sampling
-        valid_token_ids_tensor = torch.tensor(valid_token_ids,
-                                              dtype=torch.long,
-                                              device="cpu")
-
-        # Sample from the valid token ids
-        warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
-            0, len(valid_token_ids_tensor), (batch_size, prompt_len))]
-
-        # Create requests to be used for prefill steps
-        dummy_requests = [
-            NewRequestData(
-                req_id="warmup",
-                prompt_token_ids=warmup_tokens_tensor[i].tolist(),
-                prompt="test",
-                mm_inputs=[],
-                mm_hashes=[],
-                mm_positions=[],
-                sampling_params=SamplingParams(max_tokens=num_decode_tokens),
-                block_ids=[0],
-                num_computed_tokens=0,
-                lora_request=None,
-            ) for i in range(batch_size)
-        ]
-
-        # Set up dummy cached_requests to be used for decode steps
-        cached_requests = [
-            CachedRequestData(
-                req_id=req.req_id,
-                resumed_from_preemption=False,
-                new_token_ids=[
-                    valid_token_ids_tensor[torch.randint(
-                        0, len(valid_token_ids_tensor), (1, )).item()]
-                ],  # placeholder token
-                new_block_ids=req.block_ids,
-                num_computed_tokens=req.num_computed_tokens,
-            ) for req in dummy_requests
-        ]
-
-        # To be used for execute_model, start with scheduled_new_reqs
-        # for prefill
-        scheduler_output = SchedulerOutput(
-            scheduled_new_reqs=dummy_requests,
-            scheduled_cached_reqs=[],
-            num_scheduled_tokens={i: prompt_len
-                                  for i in range(batch_size)},
-            total_num_scheduled_tokens=sum(prompt_len
-                                           for _ in range(batch_size)),
-            scheduled_spec_decode_tokens={},
-            scheduled_encoder_inputs={},
-            num_common_prefix_blocks=0,
-            finished_req_ids=set(),
-            free_encoder_input_ids=[],
-        )
-
-        # First full forward pass
-        logger.info("Warmup 1/2: Prefill...")
-        self.execute_model(scheduler_output)  # Prefill step
-
-        # Switch to cached requests to trigger decoding steps
-        scheduler_output.scheduled_new_reqs = []
-        scheduler_output.scheduled_cached_reqs = cached_requests
-
-        logger.info("Warmup 1/2: Decoding...")
-        for _ in range(num_decode_tokens - 1):
-            self.execute_model(scheduler_output)
-
-        # update_lazyhandle
-        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
-            from torch_sendnn import torch_sendnn
-            ul_start_time = time.time()
-            torch_sendnn.update_lazyhandle()
-            ul_stop_time = time.time()
-            logger.info("update_lazyhandle() done (duration: %.3fs",
-                        ul_stop_time - ul_start_time)
-
-        # Second full forward pass
-        logger.info("Warmup 2/2: Prefill step...")
-        scheduler_output.scheduled_new_reqs = dummy_requests
-        scheduler_output.scheduled_cached_reqs = []
-        self.execute_model(scheduler_output)
-
-        # Switch to cached requests to trigger decoding steps
-        scheduler_output.scheduled_new_reqs = []
-        scheduler_output.scheduled_cached_reqs = cached_requests
-
-        logger.info("[Warmup 2/2: Decoding steps...")
-        for _ in range(num_decode_tokens - 1):
-            self.execute_model(scheduler_output)
-
-        warmup_end_t = time.time()
-        warmup_total_t = warmup_end_t - warmup_start_t
-        logger.info("Warmup finished.")
-        logger.info(
-            "Warmup took %.3fs (for prompt length %d and max output tokens %d)",
-            warmup_total_t, prompt_len, num_decode_tokens)
-
     def check_health(self) -> None:
         """Basic health check (override for device-specific checks)."""
         # TODO: Implement something!
@@ -344,6 +233,117 @@ def load_model(self):
         load_model_total_t = load_model_end_t - load_model_start_t
         logger.info("load model took %.3fs", load_model_total_t)
 
+    def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
+                                 special_token_ids, batch_size):
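+        """Warm up the Spyre device for one fixed warmup shape.
+
+        Builds batch_size dummy requests of prompt_len random non-special
+        tokens each, then runs two full prefill + decode passes
+        (max_tokens=num_decode_tokens) through execute_model().
+        """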
+
+        warmup_start_t = time.time()
+        # NOTE(ngl): empty tensor causes spyre to hang, so using
+        # randint without 0 and the eos and bos token
+
+        # Create a list of valid values between 1 (inclusive) and vocab
+        # size (exclusive) by excluding the eos and bos token ids
+        # (in special_token_ids)
+        vocab_size = self.model_runner.vocab_size
+        valid_token_ids = [
+            i for i in range(1, vocab_size) if i not in set(special_token_ids)
+        ]
+        # Convert to tensor for sampling
+        valid_token_ids_tensor = torch.tensor(valid_token_ids,
+                                              dtype=torch.long,
+                                              device="cpu")
+
+        # Sample from the valid token ids
+        warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
+            0, len(valid_token_ids_tensor), (batch_size, prompt_len))]
+
+        # Create requests to be used for prefill steps
+        dummy_requests = [
+            NewRequestData(
+                req_id="warmup",
+                prompt_token_ids=warmup_tokens_tensor[i].tolist(),
+                prompt="test",
+                mm_inputs=[],
+                mm_hashes=[],
+                mm_positions=[],
+                sampling_params=SamplingParams(max_tokens=num_decode_tokens),
+                block_ids=[0],
+                num_computed_tokens=0,
+                lora_request=None,
+            ) for i in range(batch_size)
+        ]
+
+        # Set up dummy cached_requests to be used for decode steps
+        cached_requests = [
+            CachedRequestData(
+                req_id=req.req_id,
+                resumed_from_preemption=False,
+                new_token_ids=[
+                    valid_token_ids_tensor[torch.randint(
+                        0, len(valid_token_ids_tensor), (1, )).item()]
+                ],  # placeholder token
+                new_block_ids=req.block_ids,
+                num_computed_tokens=req.num_computed_tokens,
+            ) for req in dummy_requests
+        ]
+
+        # To be used for execute_model, start with scheduled_new_reqs
+        # for prefill
+        scheduler_output = SchedulerOutput(
+            scheduled_new_reqs=dummy_requests,
+            scheduled_cached_reqs=[],
+            num_scheduled_tokens={i: prompt_len
+                                  for i in range(batch_size)},
+            total_num_scheduled_tokens=sum(prompt_len
+                                           for _ in range(batch_size)),
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=0,
+            finished_req_ids=set(),
+            free_encoder_input_ids=[],
+        )
+
+        # First full forward pass
+        logger.info("Warmup 1/2: Prefill...")
+        self.execute_model(scheduler_output)  # Prefill step
+
+        # Switch to cached requests to trigger decoding steps
+        scheduler_output.scheduled_new_reqs = []
+        scheduler_output.scheduled_cached_reqs = cached_requests
+
+        logger.info("Warmup 1/2: Decoding...")
+        for _ in range(num_decode_tokens - 1):
+            self.execute_model(scheduler_output)
+
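+        # Presumably the first prefill/decode pass above triggers graph
+        # compilation on the sendnn_decoder backend; update_lazyhandle() then
+        # finalizes the lazy graph handles so that the second pass below runs
+        # on the compiled graphs.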
+        # update_lazyhandle
+        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
+            from torch_sendnn import torch_sendnn
+            ul_start_time = time.time()
+            torch_sendnn.update_lazyhandle()
+            ul_stop_time = time.time()
+            logger.info("update_lazyhandle() done (duration: %.3fs)",
+                        ul_stop_time - ul_start_time)
+
+        # Second full forward pass
+        logger.info("Warmup 2/2: Prefill step...")
+        scheduler_output.scheduled_new_reqs = dummy_requests
+        scheduler_output.scheduled_cached_reqs = []
+        self.execute_model(scheduler_output)
+
+        # Switch to cached requests to trigger decoding steps
+        scheduler_output.scheduled_new_reqs = []
+        scheduler_output.scheduled_cached_reqs = cached_requests
+
+        logger.info("Warmup 2/2: Decoding steps...")
+        for _ in range(num_decode_tokens - 1):
+            self.execute_model(scheduler_output)
+
+        warmup_end_t = time.time()
+        warmup_total_t = warmup_end_t - warmup_start_t
+        logger.info("Warmup finished.")
+        logger.info(
+            "Warmup took %.3fs (for prompt length %d and max output tokens %d)",
+            warmup_total_t, prompt_len, num_decode_tokens)
+
     @property
     def do_metadata_broadcast(self) -> bool:
         return True
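
For reference, here is a minimal sketch of how a caller such as `compile_or_warm_up_model()` (first hunk above) might drive `_warmup_spyre_fixed_size()` over a set of warmup shapes. The `warmup_shapes` list and the `run_all_warmups` helper are illustrative assumptions, not part of this commit.

```python
import time

def run_all_warmups(worker, warmup_shapes, special_token_ids):
    # warmup_shapes: assumed list of (prompt_len, num_decode_tokens, batch_size)
    all_start_t = time.time()
    for prompt_len, num_decode_tokens, batch_size in warmup_shapes:
        worker._warmup_spyre_fixed_size(prompt_len, num_decode_tokens,
                                        special_token_ids, batch_size)
    all_total_t = time.time() - all_start_t
    # Mirrors the "combinations finished. Total warmup time" log retained in
    # compile_or_warm_up_model().
    print(f"All {len(warmup_shapes)} warmup combinations finished. "
          f"Total warmup time {all_total_t:.3f}s.")
```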