@@ -250,13 +250,21 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         # Convert to tensor for sampling
         valid_token_ids_tensor = torch.tensor(valid_token_ids,
                                               dtype=torch.long,
-                                              device="cpu")
+                                              device=torch.device("cpu"))

         # Sample from the valid token ids
         warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
             0, len(valid_token_ids_tensor), (batch_size, prompt_len))]

-        # Create requests to be used for prefill steps
+        extra_kwargs = {}
+        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND not in [
+                "sendnn", "sendnn_decoder"
+        ]:
+            # Bug in PyTorch 2.3.1, fixed in 2.4.1, in the SDPA flash
+            # CPU impl when padding too much
+            extra_kwargs["attn_algorithm"] = "math"
+
+        # Set up dummy requests for prefill steps
         dummy_requests = [
             NewRequestData(
                 req_id="warmup",
@@ -272,7 +280,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
             ) for i in range(batch_size)
         ]

-        # Set up dummy cached_requests to be used for decode steps
+        # Set up dummy cached_requests for decode steps
         cached_requests = [
             CachedRequestData(
                 req_id=req.req_id,
@@ -286,8 +294,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
             ) for req in dummy_requests
         ]

-        # To be used for execute_model, start with scheduled_new_reqs
-        # for prefill
+        # Set up scheduler_output for execute_model
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=dummy_requests,
             scheduled_cached_reqs=[],
@@ -303,18 +310,10 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         )

         # First full forward pass
-        logger.info("Warmup 1/2: Prefill...")
-        self.execute_model(scheduler_output)  # Prefill step
-
-        # Switch to cached requests to trigger decoding steps
-        scheduler_output.scheduled_new_reqs = []
-        scheduler_output.scheduled_cached_reqs = cached_requests
+        logger.info("Warmup forward pass 1/2...")
+        self._warmup_model_forward_pass(scheduler_output, dummy_requests,
+                                        cached_requests, num_decode_tokens)

-        logger.info("Warmup 1/2: Decoding...")
-        for _ in range(num_decode_tokens - 1):
-            self.execute_model(scheduler_output)
-
-        # update_lazyhandle
         if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
             from torch_sendnn import torch_sendnn
             ul_start_time = time.time()
@@ -324,18 +323,9 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
                         ul_stop_time - ul_start_time)

         # Second full forward pass
-        logger.info("Warmup 2/2: Prefill step...")
-        scheduler_output.scheduled_new_reqs = dummy_requests
-        scheduler_output.scheduled_cached_reqs = []
-        self.execute_model(scheduler_output)
-
-        # Switch to cached requests to trigger decoding steps
-        scheduler_output.scheduled_new_reqs = []
-        scheduler_output.scheduled_cached_reqs = cached_requests
-
-        logger.info("Warmup 2/2: Decoding steps...")
-        for _ in range(num_decode_tokens - 1):
-            self.execute_model(scheduler_output)
+        logger.info("Warmup forward pass 2/2...")
+        self._warmup_model_forward_pass(scheduler_output, dummy_requests,
+                                        cached_requests, num_decode_tokens)

         warmup_end_t = time.time()
         warmup_total_t = warmup_end_t - warmup_start_t
@@ -344,6 +334,24 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
             "Warmup took %.3fs (for prompt length %d and max output tokens %d)",
             warmup_total_t, prompt_len, num_decode_tokens)

+    def _warmup_model_forward_pass(
+        self,
+        scheduler_output: SchedulerOutput,
+        requests: List[NewRequestData],
+        cached_requests: List[CachedRequestData],
+        num_decode_tokens: int,
+    ):
+        """Run one complete warmup forward pass: a prefill step
+        followed by num_decode_tokens - 1 decode steps."""
+        scheduler_output.scheduled_new_reqs = requests
+        scheduler_output.scheduled_cached_reqs = []
+        self.execute_model(scheduler_output)  # Prefill
+
+        # Switch to cached requests to trigger decoding steps
+        scheduler_output.scheduled_new_reqs = []
+        scheduler_output.scheduled_cached_reqs = cached_requests
+        for _ in range(num_decode_tokens - 1):
+            self.execute_model(scheduler_output)
+
     @property
     def do_metadata_broadcast(self) -> bool:
         return True
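
On the `attn_algorithm = "math"` workaround in the first hunk: the added branch routes around a bug in PyTorch 2.3.1's flash SDPA CPU implementation (fixed in 2.4.1) that is triggered by heavy padding. The `attn_algorithm` kwarg is the model stack's own knob; purely for illustration, here is a minimal sketch of the same idea in plain PyTorch, forcing the math (reference) SDPA backend. The tensor shapes are arbitrary placeholders, not values from this code.

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Arbitrary placeholder shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(1, 4, 64, 32)
k = torch.randn(1, 4, 64, 32)
v = torch.randn(1, 4, 64, 32)

# Restrict SDPA to the math (reference) kernel inside this block,
# sidestepping the flash CPU implementation entirely.
with sdpa_kernel(SDPBackend.MATH):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 4, 64, 32])
```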
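The new `_warmup_model_forward_pass` helper deduplicates the prefill-then-decode sequence that both warmup passes previously repeated inline. Below is a minimal, runnable sketch of that pattern; the stub dataclasses stand in for vLLM's `SchedulerOutput`, `NewRequestData`, and `CachedRequestData` (the real types carry many more fields), and `execute_model` is a placeholder for the real forward pass.

```python
from dataclasses import dataclass, field
from typing import List

# Stubs standing in for vLLM's scheduler types.
@dataclass
class NewRequestData:
    req_id: str

@dataclass
class CachedRequestData:
    req_id: str

@dataclass
class SchedulerOutput:
    scheduled_new_reqs: List[NewRequestData] = field(default_factory=list)
    scheduled_cached_reqs: List[CachedRequestData] = field(default_factory=list)

class Worker:
    def execute_model(self, scheduler_output: SchedulerOutput) -> None:
        # Placeholder for the real forward pass.
        kind = "prefill" if scheduler_output.scheduled_new_reqs else "decode"
        print(f"execute_model: {kind} step")

    def _warmup_model_forward_pass(
        self,
        scheduler_output: SchedulerOutput,
        requests: List[NewRequestData],
        cached_requests: List[CachedRequestData],
        num_decode_tokens: int,
    ) -> None:
        """One full warmup pass: a prefill step, then decode steps."""
        # Prefill: schedule all requests as new.
        scheduler_output.scheduled_new_reqs = requests
        scheduler_output.scheduled_cached_reqs = []
        self.execute_model(scheduler_output)

        # Decode: switch to cached requests, one step per remaining token.
        scheduler_output.scheduled_new_reqs = []
        scheduler_output.scheduled_cached_reqs = cached_requests
        for _ in range(num_decode_tokens - 1):
            self.execute_model(scheduler_output)

worker = Worker()
reqs = [NewRequestData(req_id="warmup")]
cached = [CachedRequestData(req_id=r.req_id) for r in reqs]
out = SchedulerOutput()
for i in (1, 2):  # two full passes, mirroring the warmup routine
    print(f"Warmup forward pass {i}/2...")
    worker._warmup_model_forward_pass(out, reqs, cached, num_decode_tokens=3)
```

Mutating one shared `SchedulerOutput` between prefill and decode, rather than constructing a fresh one per step, mirrors how the original warmup code drove `execute_model`; the helper just gives both passes a single place to do it.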