Commit c9bfad9
✅ add assertions for warmup mode context (#294)
# Description

This PR adds assertions for operations performed within the `warmup context`. Earlier we saw a bug in which the `decode` warmup was moved outside the warmup context by a single indentation mistake, causing a segmentation fault with no clue as to why it was happening. These assertions should catch that error early, even when running on CPU.

---------

Signed-off-by: Prashant Gupta <[email protected]>
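To make the pattern in the diff below easier to follow, here is a minimal, standalone sketch of the guard this PR adds: a module-level flag that is only set inside the warmup context manager, plus an assertion in the warmup routine. The names are simplified for illustration and are not the actual SpyreWorker / torch_sendnn API.

    import contextlib

    _inside_warmup_mode = False  # module-level flag: are we inside the warmup context?


    @contextlib.contextmanager
    def maybe_warmup_context():
        # Set the flag only for the duration of the context so that any
        # warmup step executed outside of it trips the assertion below.
        global _inside_warmup_mode
        _inside_warmup_mode = True
        try:
            yield
        finally:
            _inside_warmup_mode = False


    def run_decode_warmup():
        # Fail fast with a clear message instead of a mysterious crash later.
        assert _inside_warmup_mode, "decode warmup called outside the warmup context"


    with maybe_warmup_context():
        run_decode_warmup()  # fine: flag is set

    # run_decode_warmup()   # outside the context: raises AssertionError

With this in place, an accidental de-indent that moves a warmup call outside the `with` block fails immediately with a readable AssertionError instead of a segmentation fault deep in graph compilation.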
1 parent a86787d commit c9bfad9

File tree

1 file changed (+80, -58 lines)


vllm_spyre/v1/worker/spyre_worker.py

Lines changed: 80 additions & 58 deletions
@@ -33,15 +33,21 @@
 
 logger = init_logger(__name__)
 
+# var to make sure we always warmup with the right context
+_inside_warmup_mode = False
+
 
 @contextlib.contextmanager
 def _maybe_warmup_context():
+    global _inside_warmup_mode
     warmup_context = contextlib.nullcontext
     if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn":
         from torch_sendnn import warmup_mode
         warmup_context = warmup_mode
     with warmup_context():
+        _inside_warmup_mode = True
         yield
+        _inside_warmup_mode = False
 
 
 class SpyreWorker(WorkerBaseV1):
@@ -319,7 +325,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
         warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
             0, len(valid_token_ids_tensor), (batch_size + 1, prompt_len))]
 
-        dummy_requests = [
+        dummy_requests: list[NewRequestData] = [
             NewRequestData(
                 req_id="warmup-%d" % (i),
                 prompt_token_ids=warmup_tokens_tensor[i].tolist(),
@@ -336,63 +342,10 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
         add_dummy_request = dummy_requests.pop(-1)
 
         with _maybe_warmup_context():
-            for i, req in enumerate(dummy_requests):
-                scheduler_output = SchedulerOutput(
-                    scheduled_new_reqs=[req],
-                    scheduled_cached_reqs=CachedRequestData.make_empty(),
-                    num_scheduled_tokens={req.req_id: prompt_len},
-                    total_num_scheduled_tokens=prompt_len,
-                    scheduled_spec_decode_tokens={},
-                    scheduled_encoder_inputs={},
-                    num_common_prefix_blocks=0,
-                    finished_req_ids=set(),
-                    free_encoder_input_ids=[],
-                    structured_output_request_ids={},
-                    grammar_bitmask=None,
-                )
-                logger.info("[WARMUP] Prefill %d/%d...", i + 1, batch_size)
-                self.execute_model(scheduler_output)
-
-            # one decode iteration across all sequences
-            req_ids = []
-            new_token_ids = []
-            new_block_ids = []
-            num_computed_tokens = []
-            for req in dummy_requests:
-                req_ids.append(req.req_id)
-                new_token_ids.append([
-                    valid_token_ids_tensor[torch.randint(
-                        0, len(valid_token_ids_tensor), (1, )).item()]
-                ])  # placeholder token
-                new_block_ids.append([req.block_ids])
-                num_computed_tokens.append(prompt_len)
-            cached_request_data = CachedRequestData(
-                req_ids=req_ids,
-                resumed_from_preemption=False,
-                new_token_ids=new_token_ids,
-                new_block_ids=new_block_ids,
-                num_computed_tokens=num_computed_tokens,
-            )
-
-            scheduler_output = SchedulerOutput(
-                scheduled_new_reqs=[],
-                scheduled_cached_reqs=cached_request_data,
-                num_scheduled_tokens={
-                    f"warmup-{i}": 1
-                    for i in range(batch_size)
-                },
-                total_num_scheduled_tokens=batch_size,
-                scheduled_spec_decode_tokens={},
-                scheduled_encoder_inputs={},
-                num_common_prefix_blocks=0,
-                finished_req_ids=set(),
-                free_encoder_input_ids=[],
-                structured_output_request_ids={},
-                grammar_bitmask=None,
-            )
-            logger.info("[WARMUP] Decode...")
-            self.execute_model(scheduler_output)
-            self._cleanup_model_runner(request=dummy_requests)
+            self._dynamic_warmup(dummy_requests=dummy_requests,
+                                 prompt_len=prompt_len,
+                                 batch_size=batch_size,
+                                 valid_token_ids_tensor=valid_token_ids_tensor)
 
         # warmup_mode completes the graph compilation, but we need to do
         # one additional prefill to deploy the compiled program to the device,
@@ -560,6 +513,75 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
                     num_decode_tokens, warmup_total_t, compile_cache_str)
         maybe_override_signals_handler()
 
+    def _dynamic_warmup(
+        self,
+        dummy_requests: list[NewRequestData],
+        prompt_len: int,
+        batch_size: int,
+        valid_token_ids_tensor: torch.Tensor,
+    ) -> None:
+
+        assert (
+            _inside_warmup_mode
+        ), "it looks like you are outside the warmup context for warmup"
+
+        for i, req in enumerate(dummy_requests):
+            scheduler_output = SchedulerOutput(
+                scheduled_new_reqs=[req],
+                scheduled_cached_reqs=CachedRequestData.make_empty(),
+                num_scheduled_tokens={req.req_id: prompt_len},
+                total_num_scheduled_tokens=prompt_len,
+                scheduled_spec_decode_tokens={},
+                scheduled_encoder_inputs={},
+                num_common_prefix_blocks=0,
+                finished_req_ids=set(),
+                free_encoder_input_ids=[],
+                structured_output_request_ids={},
+                grammar_bitmask=None,
+            )
+            logger.info("[WARMUP] Prefill %d/%d...", i + 1, batch_size)
+
+            self.execute_model(scheduler_output)
+
+        # one decode iteration across all sequences
+        req_ids = []
+        new_token_ids = []
+        new_block_ids = []
+        num_computed_tokens = []
+        for req in dummy_requests:
+            req_ids.append(req.req_id)
+            new_token_ids.append([
+                valid_token_ids_tensor[torch.randint(
+                    0, len(valid_token_ids_tensor), (1, )).item()]
+            ])  # placeholder token
+            new_block_ids.append([req.block_ids])
+            num_computed_tokens.append(prompt_len)
+        cached_request_data = CachedRequestData(
+            req_ids=req_ids,
+            resumed_from_preemption=False,
+            new_token_ids=new_token_ids,
+            new_block_ids=new_block_ids,
+            num_computed_tokens=num_computed_tokens,
+        )
+
+        scheduler_output = SchedulerOutput(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=cached_request_data,
+            num_scheduled_tokens={f"warmup-{i}": 1
+                                  for i in range(batch_size)},
+            total_num_scheduled_tokens=batch_size,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=0,
+            finished_req_ids=set(),
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+        )
+        logger.info("[WARMUP] Decode...")
+        self.execute_model(scheduler_output)
+        self._cleanup_model_runner(request=dummy_requests)
+
     def _warmup_model_forward_pass(
         self,
         scheduler_output: SchedulerOutput,
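Not part of the commit, but a hedged sketch of how the new guard could be exercised from a test. The worker construction below is a hypothetical shortcut (it skips the real SpyreWorker initialization, which normally requires device and model-runner setup), and the test layout does not reflect the actual vllm-spyre test suite.

    import pytest

    from vllm_spyre.v1.worker import spyre_worker


    def test_dynamic_warmup_requires_warmup_context(monkeypatch):
        # Outside _maybe_warmup_context() the module-level flag is False, so
        # _dynamic_warmup should fail fast with an AssertionError rather than
        # segfaulting later during graph compilation.
        monkeypatch.setattr(spyre_worker, "_inside_warmup_mode", False)
        worker = object.__new__(spyre_worker.SpyreWorker)  # skip heavy __init__
        with pytest.raises(AssertionError):
            worker._dynamic_warmup(dummy_requests=[],
                                   prompt_len=1,
                                   batch_size=1,
                                   valid_token_ids_tensor=None)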
