Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions vllm_spyre/v1/worker/spyre_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,21 @@

logger = init_logger(__name__)

# Module-level flag: True only while execution is inside
# `_maybe_warmup_context()`; warmup code asserts on it before running the model.
_inside_warmup_mode = False


@contextlib.contextmanager
def _maybe_warmup_context():
    """Enter the backend warmup context and track it in a module-level flag.

    When ``envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND`` is ``"sendnn"``, the body
    runs inside ``torch_sendnn.warmup_mode()``; for any other backend a
    ``contextlib.nullcontext`` is used instead.

    While the body is executing, the module-level ``_inside_warmup_mode``
    flag is ``True`` so warmup code can assert it is running in the right
    context. The flag is reset to ``False`` on exit, including when the body
    raises, so a failed warmup cannot leave the flag stuck at ``True``.

    Yields:
        None.
    """
    global _inside_warmup_mode
    warmup_context = contextlib.nullcontext
    if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn":
        # Imported lazily: only needed when the sendnn backend is selected.
        from torch_sendnn import warmup_mode
        warmup_context = warmup_mode
    with warmup_context():
        _inside_warmup_mode = True
        try:
            yield
        finally:
            # An exception in the caller's `with` body is re-raised at the
            # `yield`; reset the flag here so it never outlives the context.
            _inside_warmup_mode = False


class SpyreWorker(WorkerBaseV1):
Expand Down Expand Up @@ -351,6 +357,8 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
grammar_bitmask=None,
)
logger.info("[WARMUP] Prefill %d/%d...", i + 1, batch_size)
assert _inside_warmup_mode, \
"it looks like you are outside the warmup context for prefill"
self.execute_model(scheduler_output)

# one decode iteration across all sequences
Expand Down Expand Up @@ -391,6 +399,8 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
grammar_bitmask=None,
)
logger.info("[WARMUP] Decode...")
assert _inside_warmup_mode, \
"it looks like you are outside the warmup context for decode"
self.execute_model(scheduler_output)
self._cleanup_model_runner(request=dummy_requests)

Expand Down
Loading