
logger = init_logger(__name__)

+# var to make sure we always warmup with the right context
+_inside_warmup_mode = False
+

@contextlib.contextmanager
def _maybe_warmup_context():
+    global _inside_warmup_mode
    warmup_context = contextlib.nullcontext
    if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn":
        from torch_sendnn import warmup_mode
        warmup_context = warmup_mode
    with warmup_context():
+        _inside_warmup_mode = True
        yield
+        _inside_warmup_mode = False


class SpyreWorker(WorkerBaseV1):
@@ -319,7 +325,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
        warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
            0, len(valid_token_ids_tensor), (batch_size + 1, prompt_len))]

-        dummy_requests = [
+        dummy_requests: list[NewRequestData] = [
            NewRequestData(
                req_id="warmup-%d" % (i),
                prompt_token_ids=warmup_tokens_tensor[i].tolist(),
@@ -336,63 +342,10 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
        add_dummy_request = dummy_requests.pop(-1)

        with _maybe_warmup_context():
-            for i, req in enumerate(dummy_requests):
-                scheduler_output = SchedulerOutput(
-                    scheduled_new_reqs=[req],
-                    scheduled_cached_reqs=CachedRequestData.make_empty(),
-                    num_scheduled_tokens={req.req_id: prompt_len},
-                    total_num_scheduled_tokens=prompt_len,
-                    scheduled_spec_decode_tokens={},
-                    scheduled_encoder_inputs={},
-                    num_common_prefix_blocks=0,
-                    finished_req_ids=set(),
-                    free_encoder_input_ids=[],
-                    structured_output_request_ids={},
-                    grammar_bitmask=None,
-                )
-                logger.info("[WARMUP] Prefill %d/%d...", i + 1, batch_size)
-                self.execute_model(scheduler_output)
-
-            # one decode iteration across all sequences
-            req_ids = []
-            new_token_ids = []
-            new_block_ids = []
-            num_computed_tokens = []
-            for req in dummy_requests:
-                req_ids.append(req.req_id)
-                new_token_ids.append([
-                    valid_token_ids_tensor[torch.randint(
-                        0, len(valid_token_ids_tensor), (1, )).item()]
-                ])  # placeholder token
-                new_block_ids.append([req.block_ids])
-                num_computed_tokens.append(prompt_len)
-            cached_request_data = CachedRequestData(
-                req_ids=req_ids,
-                resumed_from_preemption=False,
-                new_token_ids=new_token_ids,
-                new_block_ids=new_block_ids,
-                num_computed_tokens=num_computed_tokens,
-            )
-
-            scheduler_output = SchedulerOutput(
-                scheduled_new_reqs=[],
-                scheduled_cached_reqs=cached_request_data,
-                num_scheduled_tokens={
-                    f"warmup-{i}": 1
-                    for i in range(batch_size)
-                },
-                total_num_scheduled_tokens=batch_size,
-                scheduled_spec_decode_tokens={},
-                scheduled_encoder_inputs={},
-                num_common_prefix_blocks=0,
-                finished_req_ids=set(),
-                free_encoder_input_ids=[],
-                structured_output_request_ids={},
-                grammar_bitmask=None,
-            )
-            logger.info("[WARMUP] Decode...")
-            self.execute_model(scheduler_output)
-            self._cleanup_model_runner(request=dummy_requests)
+            self._dynamic_warmup(dummy_requests=dummy_requests,
+                                 prompt_len=prompt_len,
+                                 batch_size=batch_size,
+                                 valid_token_ids_tensor=valid_token_ids_tensor)

        # warmup_mode completes the graph compilation, but we need to do
        # one additional prefill to deploy the compiled program to the device,
@@ -560,6 +513,75 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
            num_decode_tokens, warmup_total_t, compile_cache_str)
        maybe_override_signals_handler()

+    def _dynamic_warmup(
+        self,
+        dummy_requests: list[NewRequestData],
+        prompt_len: int,
+        batch_size: int,
+        valid_token_ids_tensor: torch.Tensor,
+    ) -> None:
+
+        assert _inside_warmup_mode, (
+            "_dynamic_warmup must be called inside _maybe_warmup_context()")
+
+        for i, req in enumerate(dummy_requests):
+            scheduler_output = SchedulerOutput(
+                scheduled_new_reqs=[req],
+                scheduled_cached_reqs=CachedRequestData.make_empty(),
+                num_scheduled_tokens={req.req_id: prompt_len},
+                total_num_scheduled_tokens=prompt_len,
+                scheduled_spec_decode_tokens={},
+                scheduled_encoder_inputs={},
+                num_common_prefix_blocks=0,
+                finished_req_ids=set(),
+                free_encoder_input_ids=[],
+                structured_output_request_ids={},
+                grammar_bitmask=None,
+            )
+            logger.info("[WARMUP] Prefill %d/%d...", i + 1, batch_size)
+
+            self.execute_model(scheduler_output)
+
+        # one decode iteration across all sequences
+        req_ids = []
+        new_token_ids = []
+        new_block_ids = []
+        num_computed_tokens = []
+        for req in dummy_requests:
+            req_ids.append(req.req_id)
+            new_token_ids.append([
+                valid_token_ids_tensor[torch.randint(
+                    0, len(valid_token_ids_tensor), (1, )).item()]
+            ])  # placeholder token
+            new_block_ids.append([req.block_ids])
+            num_computed_tokens.append(prompt_len)
+        cached_request_data = CachedRequestData(
+            req_ids=req_ids,
+            resumed_from_preemption=False,
+            new_token_ids=new_token_ids,
+            new_block_ids=new_block_ids,
+            num_computed_tokens=num_computed_tokens,
+        )
+
+        scheduler_output = SchedulerOutput(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=cached_request_data,
+            num_scheduled_tokens={f"warmup-{i}": 1
+                                  for i in range(batch_size)},
+            total_num_scheduled_tokens=batch_size,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=0,
+            finished_req_ids=set(),
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+        )
+        logger.info("[WARMUP] Decode...")
+        self.execute_model(scheduler_output)
+        self._cleanup_model_runner(request=dummy_requests)
+
    def _warmup_model_forward_pass(
        self,
        scheduler_output: SchedulerOutput,
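For reference, here is a minimal standalone sketch of the guard pattern this change introduces: a module-level flag toggled by the warmup context manager and asserted by the helper that must only run during warmup. The names mirror the diff, but the bodies are placeholders, and the try/finally is an addition of this sketch (the diff clears the flag directly after the yield) so the flag is reset even if warmup raises.

# Illustrative sketch only, not part of the diff.
import contextlib

_inside_warmup_mode = False


@contextlib.contextmanager
def _maybe_warmup_context():
    # Set the flag for the duration of the block; reset it even on error.
    global _inside_warmup_mode
    _inside_warmup_mode = True
    try:
        yield
    finally:
        _inside_warmup_mode = False


def _dynamic_warmup_step():
    # Refuse to run outside the warmup context.
    assert _inside_warmup_mode, (
        "must be called inside _maybe_warmup_context()")
    print("warming up")


with _maybe_warmup_context():
    _dynamic_warmup_step()  # ok: flag is set
# _dynamic_warmup_step()    # outside the context this would raise AssertionError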