diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index 45d63bccd..76ad21fd8 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -459,27 +459,6 @@ def pad_input_ids(
 
         return input_ids, position_ids, mask
 
-    def _raw_model_forward(
-        self,
-        input_ids: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        past_key_value_states: Optional[List[Tuple[torch.Tensor,
-                                                    torch.Tensor]]] = None,
-        use_cache: bool = False,
-        only_last_token: bool = False,
-        attn_algorithm: Optional[str] = None
-    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor,
-                                                  torch.Tensor]]]]:
-
-        return self.model.model(input_ids,
-                                mask=mask,
-                                position_ids=position_ids,
-                                past_key_value_states=past_key_value_states,
-                                use_cache=use_cache,
-                                only_last_token=only_last_token,
-                                attn_algorithm=attn_algorithm)
-
     def get_kv_cache_spec(self) -> KVCacheSpec:
         """
         This method should generate the KVCache spec by parsing the kv cache
diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py
index 78c06a1bb..6fca9cd37 100644
--- a/vllm_spyre/v1/worker/spyre_worker.py
+++ b/vllm_spyre/v1/worker/spyre_worker.py
@@ -14,7 +14,9 @@
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
-from vllm.v1.core.scheduler import SchedulerOutput
+from vllm.sampling_params import SamplingParams
+from vllm.v1.core.scheduler import (CachedRequestData, NewRequestData,
+                                    SchedulerOutput)
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.worker_base import WorkerBase as WorkerBaseV1
@@ -233,12 +235,7 @@ def load_model(self):
 
     def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
                                  special_token_ids, batch_size):
-        # TODO See if we can use `self.execute_model` instead for the warmup
-        # It's slightly risky to implement different forward pass logic here,
-        # which can go out of sync with the real forward pass and cause problems
-        # for torch.compile
-
         # warmup the model
         warmup_start_t = time.time()
         # NOTE(ngl): empty tensor causes spyre to hang, so using
         # randint without 0 and the eos and bos token
@@ -254,6 +251,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         valid_token_ids_tensor = torch.tensor(valid_token_ids,
                                               dtype=torch.long,
                                               device=torch.device("cpu"))
+        # Sample from the valid token ids
         warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
             0, len(valid_token_ids_tensor), (batch_size, prompt_len))]
 
@@ -266,82 +264,93 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
             # impl when padding too much
             extra_kwargs["attn_algorithm"] = "math"
 
-        print(f"[SpyreWorker] warmup for prompt length "
-              f"{prompt_len} and max output tokens {num_decode_tokens}.")
+        # Set up dummy requests for prefill steps
+        dummy_requests = [
+            NewRequestData(
+                req_id="warmup",
+                prompt_token_ids=warmup_tokens_tensor[i].tolist(),
+                prompt="test",
+                mm_inputs=[],
+                mm_hashes=[],
+                mm_positions=[],
+                sampling_params=SamplingParams(max_tokens=num_decode_tokens),
+                block_ids=[0],
+                num_computed_tokens=0,
+                lora_request=None,
+            ) for i in range(batch_size)
+        ]
+
+        # Set up dummy cached_requests for decode steps
+        cached_requests = [
+            CachedRequestData(
+                req_id=req.req_id,
+                resumed_from_preemption=False,
+                new_token_ids=[
+                    valid_token_ids_tensor[torch.randint(
+                        0, len(valid_token_ids_tensor), (1, )).item()]
+                ],  # placeholder token
+                new_block_ids=req.block_ids,
+                num_computed_tokens=req.num_computed_tokens,
+            ) for req in dummy_requests
+        ]
 
-        # 1. trace
-        print("[SpyreWorker] warmup 1/2...")
-        # TODO: torch_sendnn.CleanGraph() should be necessary?
-        # warmup 1st forward pass
-        self._warmup_model_forward_pass(warmup_tokens_tensor,
-                                        valid_token_ids_tensor, prompt_len,
-                                        num_decode_tokens, batch_size,
-                                        extra_kwargs)
+        # Set up scheduler_output for execute_model
+        scheduler_output = SchedulerOutput(
+            scheduled_new_reqs=dummy_requests,
+            scheduled_cached_reqs=[],
+            num_scheduled_tokens={i: prompt_len
+                                  for i in range(batch_size)},
+            total_num_scheduled_tokens=sum(prompt_len
+                                           for _ in range(batch_size)),
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=0,
+            finished_req_ids=set(),
+            free_encoder_input_ids=[],
+        )
+
+        # First full forward pass
+        logger.info("Warmup forward pass 1/2...")
+        self._warmup_model_forward_pass(scheduler_output, dummy_requests,
+                                        cached_requests, num_decode_tokens)
 
-        # 2. compile
-        print("[SpyreWorker] warmup 2/2...")
         if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
             from torch_sendnn import torch_sendnn
             ul_start_time = time.time()
             torch_sendnn.update_lazyhandle()
             ul_stop_time = time.time()
-            ul_total_t = ul_stop_time - ul_start_time
-            print(f"update_lazyhandle() done (duration: {ul_total_t}s)")
+            logger.info("update_lazyhandle() done (duration: %.3fs)",
+                        ul_stop_time - ul_start_time)
 
-        # warmup 2nd forward pass
-        self._warmup_model_forward_pass(warmup_tokens_tensor,
-                                        valid_token_ids_tensor, prompt_len,
-                                        num_decode_tokens, batch_size,
-                                        extra_kwargs)
+        # Second full forward pass
+        logger.info("Warmup forward pass 2/2...")
+        self._warmup_model_forward_pass(scheduler_output, dummy_requests,
+                                        cached_requests, num_decode_tokens)
 
         warmup_end_t = time.time()
         warmup_total_t = warmup_end_t - warmup_start_t
-        print("[SpyreWorker] ... warmup finished.")
-        print(f"\twarmup took {warmup_total_t}s (for prompt length"
-              f"{prompt_len} and max output tokens {num_decode_tokens})")
-
-    def _warmup_model_forward_pass(self, warmup_tokens_tensor,
-                                   valid_token_ids_tensor, prompt_len,
-                                   num_decode_tokens, batch_size,
-                                   extra_kwargs):
-        # padding warmup tokens to obtain the
-        # corresponding position ids and mask
-        warmup_tokens_pad, self.model_runner._position_ids, \
-            self.model_runner._mask = self.model_runner.pad_input_ids(
-                warmup_tokens_tensor, min_pad_length=prompt_len)
-
-        logits, past_key_value_states = self.model_runner._raw_model_forward(
-            warmup_tokens_pad,
-            position_ids=self.model_runner._position_ids,
-            mask=self.model_runner._mask,
-            past_key_value_states=None,
-            use_cache=True,
-            only_last_token=True,
-            **extra_kwargs)
-        # decoding
-        for i in range(num_decode_tokens - 1):
-            # sampling next input token from vocab without bos and eos tokens
-            decode_tokens = valid_token_ids_tensor[torch.randint(
-                0, len(valid_token_ids_tensor), (batch_size, 1))]
-
-            # update mask and position_ids
-            self.model_runner._update_mask()
-            self.model_runner._update_position_ids()
-
-            if past_key_value_states is not None:
-                for layer in past_key_value_states:
-                    for tensor in layer:
-                        torch._dynamo.mark_dynamic(tensor, 2)
-
-            logits, past_key_value_states = self.model_runner.\
-                _raw_model_forward(
-                    decode_tokens,
-                    position_ids=self.model_runner._position_ids,
-                    mask=self.model_runner._mask,
-                    past_key_value_states=past_key_value_states,
-                    use_cache=True,
-                    only_last_token=True,
-                    **extra_kwargs)
+        logger.info("Warmup finished.")
+        logger.info(
+            "Warmup took %.3fs (for prompt length %d and max output tokens %d)",
+            warmup_total_t, prompt_len, num_decode_tokens)
+
+    def _warmup_model_forward_pass(
+        self,
+        scheduler_output: SchedulerOutput,
+        requests: List[NewRequestData],
+        cached_requests: List[CachedRequestData],
+        num_decode_tokens,
+    ):
+        """Handle a complete forward pass"""
+        scheduler_output.scheduled_new_reqs = requests
+        scheduler_output.scheduled_cached_reqs = []
+        self.execute_model(scheduler_output)  # Prefill
+
+        # Switch to cached requests to trigger decoding steps
+        scheduler_output.scheduled_new_reqs = []
+        scheduler_output.scheduled_cached_reqs = cached_requests
+        for _ in range(num_decode_tokens - 1):
+            self.execute_model(scheduler_output)
 
     @property
     def do_metadata_broadcast(self) -> bool:
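Since the warmup now drives the regular execute_model entry point with dummy scheduler data instead of a private forward helper, the graphs traced and compiled during warmup are the same ones used at serving time. The sketch below is an illustrative, standalone version of that pattern for a single request; it uses only the constructor fields visible in this diff (which may differ in other vLLM versions), and `worker`, `prompt_len`, `max_new_tokens`, and token id 11 are assumed stand-ins for the values the patch computes.

from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import (CachedRequestData, NewRequestData,
                                    SchedulerOutput)

# Assumption: `worker` is an already-initialized SpyreWorker and token id 11
# is a valid non-special token.
prompt_len, max_new_tokens = 64, 8

prefill_req = NewRequestData(
    req_id="warmup",
    prompt_token_ids=[11] * prompt_len,  # dummy prompt tokens
    prompt="test",
    mm_inputs=[],
    mm_hashes=[],
    mm_positions=[],
    sampling_params=SamplingParams(max_tokens=max_new_tokens),
    block_ids=[0],
    num_computed_tokens=0,
    lora_request=None,
)
scheduler_output = SchedulerOutput(
    scheduled_new_reqs=[prefill_req],
    scheduled_cached_reqs=[],
    num_scheduled_tokens={0: prompt_len},
    total_num_scheduled_tokens=prompt_len,
    scheduled_spec_decode_tokens={},
    scheduled_encoder_inputs={},
    num_common_prefix_blocks=0,
    finished_req_ids=set(),
    free_encoder_input_ids=[],
)
worker.execute_model(scheduler_output)  # prefill step

# Re-populate the same SchedulerOutput with cached request data so the
# following execute_model calls take the decode path, as in the patch.
scheduler_output.scheduled_new_reqs = []
scheduler_output.scheduled_cached_reqs = [
    CachedRequestData(
        req_id=prefill_req.req_id,
        resumed_from_preemption=False,
        new_token_ids=[11],  # placeholder token
        new_block_ids=prefill_req.block_ids,
        num_computed_tokens=prefill_req.num_computed_tokens,
    )
]
for _ in range(max_new_tokens - 1):
    worker.execute_model(scheduler_output)  # one decode step per new token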