|
33 | 33 |
|
34 | 34 | logger = init_logger(__name__) |
35 | 35 |
|
# Module-level flag recording whether execution is currently inside the
# warmup context. The warmup routines assert on it to catch code paths
# that run warmup steps outside of _maybe_warmup_context().
_inside_warmup_mode = False


@contextlib.contextmanager
def _maybe_warmup_context():
    """Enter the backend's warmup mode for the duration of the body.

    When the Spyre dynamo backend is "sendnn", the body runs inside
    ``torch_sendnn.warmup_mode``; for any other backend this is a no-op
    (``contextlib.nullcontext``). While the body runs, the module-level
    ``_inside_warmup_mode`` flag is True so warmup code can assert it is
    executing in the correct context.
    """
    global _inside_warmup_mode
    warmup_context = contextlib.nullcontext
    if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn":
        # Imported lazily: torch_sendnn is only available/needed for the
        # sendnn backend.
        from torch_sendnn import warmup_mode
        warmup_context = warmup_mode
    with warmup_context():
        _inside_warmup_mode = True
        try:
            yield
        finally:
            # Reset even if the warmup body raises, so a failed warmup
            # does not leave the flag stuck at True and make later
            # assertions pass spuriously.
            _inside_warmup_mode = False
45 | 51 |
|
46 | 52 |
|
47 | 53 | class SpyreWorker(WorkerBaseV1): |
@@ -351,6 +357,8 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): |
351 | 357 | grammar_bitmask=None, |
352 | 358 | ) |
353 | 359 | logger.info("[WARMUP] Prefill %d/%d...", i + 1, batch_size) |
| 360 | + assert _inside_warmup_mode, \ |
| 361 | + "it looks like you are outside the warmup context for prefill" |
354 | 362 | self.execute_model(scheduler_output) |
355 | 363 |
|
356 | 364 | # one decode iteration across all sequences |
@@ -391,6 +399,8 @@ def _warmup_spyre_dynamic_size(self, special_token_ids): |
391 | 399 | grammar_bitmask=None, |
392 | 400 | ) |
393 | 401 | logger.info("[WARMUP] Decode...") |
| 402 | + assert _inside_warmup_mode, \ |
| 403 | + "it looks like you are outside the warmup context for decode" |
394 | 404 | self.execute_model(scheduler_output) |
395 | 405 | self._cleanup_model_runner(request=dummy_requests) |
396 | 406 |
|
|
0 commit comments