Use execute_model for warmup #26
Changes from 8 commits
@@ -13,7 +13,8 @@
     init_distributed_environment)
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
-from vllm.v1.core.scheduler import SchedulerOutput
+from vllm.sampling_params import SamplingParams
+from vllm.v1.core.scheduler import NewRequestData, SchedulerOutput
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput

@@ -68,8 +69,51 @@ def compile_or_warm_up_model(self) -> None:
             print(f"[SpyreWorker] Warming up for prompt length {prompt_len}, "
                   f"decoding {num_decode_tokens} tokens with batch "
                   f"size {batch_size}")
-            self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens,
-                                          self.restricted_tokens, batch_size)
+
+            num_scheduled_tokens: dict = {}
+            total_num_scheduled_tokens: int = 0
+            dummy_requests: list = []
+            for i in range(batch_size):
+                dummy_requests.append(
+                    NewRequestData(
+                        req_id=f"warmup-{i}",
+                        prompt_token_ids=[1] * prompt_len,
+                        prompt="test",
+                        mm_inputs=[],
+                        mm_hashes=[],
+                        mm_positions=[],
+                        sampling_params=SamplingParams(
+                            max_tokens=num_decode_tokens),
+                        block_ids=[0],
+                        num_computed_tokens=0,
+                        lora_request=None,
+                    ))
+                num_scheduled_tokens[i] = prompt_len
+                total_num_scheduled_tokens += num_scheduled_tokens[i]
+
+            scheduler_output = SchedulerOutput(
+                scheduled_new_reqs=dummy_requests,
+                scheduled_cached_reqs=[],
+                num_scheduled_tokens=num_scheduled_tokens,
+                total_num_scheduled_tokens=total_num_scheduled_tokens,
+                scheduled_spec_decode_tokens={},
+                scheduled_encoder_inputs={},
+                num_common_prefix_blocks=0,
+                finished_req_ids=set(),
+                free_encoder_input_ids=[],
+            )
+
+            # Use execute_model for warm up
+            self.execute_model(scheduler_output)
+
+        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
+            from torch_sendnn import torch_sendnn
+            ul_start_time = time.time()
+            torch_sendnn.update_lazyhandle()
+            ul_stop_time = time.time()
+            ul_total_t = ul_stop_time - ul_start_time
+            print(f"update_lazyhandle() done (duration: {ul_total_t}s)")
+
         all_warmup_end_t = time.time()
         all_warmup_total_t = all_warmup_end_t - all_warmup_start_t
         print(f"[SpyreWorker] All warmups for "
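The per-shape prints above suggest this warmup runs once per warmup shape, i.e. once per (prompt length, new tokens, batch size) combination (see also the inline review comment further down). As a minimal sketch of that outer loop: the shape values and the `warmup_one_shape` callable below are illustrative assumptions for this note, not the actual Spyre configuration or API.

```python
from typing import Callable

# Illustrative warmup shapes: (prompt_len, num_decode_tokens, batch_size).
# Real deployments take these from the Spyre warmup configuration.
WARMUP_SHAPES = [(64, 20, 1), (128, 20, 4)]


def run_all_warmups(warmup_one_shape: Callable[[int, int, int], None],
                    shapes: list = WARMUP_SHAPES) -> None:
    """Run one warmup pass per configured (prompt, decode, batch) shape."""
    for prompt_len, num_decode_tokens, batch_size in shapes:
        print(f"[SpyreWorker] Warming up for prompt length {prompt_len}, "
              f"decoding {num_decode_tokens} tokens with batch "
              f"size {batch_size}")
        warmup_one_shape(prompt_len, num_decode_tokens, batch_size)
```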
@@ -209,118 +253,6 @@ def load_model(self):
         load_model_total_t = load_model_end_t - load_model_start_t
         print(f"\tload model took {load_model_total_t}s")

-    def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
-                                 special_token_ids, batch_size):
-        # TODO See if we can use `self.execute_model` instead for the warmup
-        # It's slightly risky to implement different forward pass logic here,
-        # which can go out of sync with the real forward pass and cause problems
-        # for torch.compile
-
-        # warmup the model
-        warmup_start_t = time.time()
-        # NOTE(ngl): empty tensor causes spyre to hang, so using
-        # randint without 0 and the eos and bos token
-
-        # Create a list of valid values between 1 (inclusive) and vocab
-        # size (exclusive) by excluding the eos and bos token ids
-        # (in special_token_ids)
-        vocab_size = self.model_runner.vocab_size
-        valid_token_ids = [
-            i for i in range(1, vocab_size) if i not in set(special_token_ids)
-        ]
-        # Convert to tensor for sampling
-        valid_token_ids_tensor = torch.tensor(valid_token_ids,
-                                              dtype=torch.long,
-                                              device=torch.device("cpu"))
-        # Sample from the valid token ids
-        warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
-            0, len(valid_token_ids_tensor), (batch_size, prompt_len))]
-
-        extra_kwargs = {}
-        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND not in [
-                "sendnn", "sendnn_decoder"
-        ]:
-            # Bug in 2.3.1 fixed in 2.4.1 for SDPA flash cpu
-            # impl when padding too much
-            extra_kwargs["attn_algorithm"] = "math"
-
-        print(f"[SpyreWorker] warmup for prompt length "
-              f"{prompt_len} and max output tokens {num_decode_tokens}.")
-
-        # 1. trace
-        print("[SpyreWorker] warmup 1/2...")
-        # TODO: torch_sendnn.CleanGraph() should be necessary?
-        # warmup 1st forward pass
-        self._warmup_model_forward_pass(warmup_tokens_tensor,
-                                        valid_token_ids_tensor, prompt_len,
-                                        num_decode_tokens, batch_size,
-                                        extra_kwargs)
-
-        # 2. compile
-        print("[SpyreWorker] warmup 2/2...")
-        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
-            from torch_sendnn import torch_sendnn
-            ul_start_time = time.time()
-            torch_sendnn.update_lazyhandle()
-            ul_stop_time = time.time()
-            ul_total_t = ul_stop_time - ul_start_time
-            print(f"update_lazyhandle() done (duration: {ul_total_t}s)")
-
-        # warmup 2nd forward pass
-        self._warmup_model_forward_pass(warmup_tokens_tensor,
-                                        valid_token_ids_tensor, prompt_len,
-                                        num_decode_tokens, batch_size,
-                                        extra_kwargs)
-
Collaborator (inline comment): this has to happen for each warmup shape (combination of …
-        warmup_end_t = time.time()
-        warmup_total_t = warmup_end_t - warmup_start_t
-        print("[SpyreWorker] ... warmup finished.")
-        print(f"\twarmup took {warmup_total_t}s (for prompt length"
-              f"{prompt_len} and max output tokens {num_decode_tokens})")
-
-    def _warmup_model_forward_pass(self, warmup_tokens_tensor,
-                                   valid_token_ids_tensor, prompt_len,
-                                   num_decode_tokens, batch_size,
-                                   extra_kwargs):
-        # padding warmup tokens to obtain the
-        # corresponding position ids and mask
-        warmup_tokens_pad, self.model_runner._position_ids, \
-            self.model_runner._mask = self.model_runner.pad_input_ids(
-                warmup_tokens_tensor, min_pad_length=prompt_len)
-
-        logits, past_key_value_states = self.model_runner._raw_model_forward(
-            warmup_tokens_pad,
-            position_ids=self.model_runner._position_ids,
-            mask=self.model_runner._mask,
-            past_key_value_states=None,
-            use_cache=True,
-            only_last_token=True,
-            **extra_kwargs)
-        # decoding
-        for i in range(num_decode_tokens - 1):
-            # sampling next input token from vocab without bos and eos tokens
-            decode_tokens = valid_token_ids_tensor[torch.randint(
-                0, len(valid_token_ids_tensor), (batch_size, 1))]
-
-            # update mask and position_ids
-            self.model_runner._update_mask()
-            self.model_runner._update_position_ids()
-
-            if past_key_value_states is not None:
-                for layer in past_key_value_states:
-                    for tensor in layer:
-                        torch._dynamo.mark_dynamic(tensor, 2)
-
-            logits, past_key_value_states = self.model_runner.\
-                _raw_model_forward(
-                    decode_tokens,
-                    position_ids=self.model_runner._position_ids,
-                    mask=self.model_runner._mask,
-                    past_key_value_states=past_key_value_states,
-                    use_cache=True,
-                    only_last_token=True,
-                    **extra_kwargs)
-
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks.
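The `torch._dynamo.mark_dynamic(tensor, 2)` calls in the removed `_warmup_model_forward_pass` mark the sequence dimension of the cached key/value tensors as dynamic, presumably so torch.compile does not retrace the graph each time the KV cache grows by one position during decode. A minimal, self-contained illustration of that mechanism follows; the toy attention function and shapes are assumptions for the sketch, not the Spyre model.

```python
import torch


@torch.compile
def attend(q: torch.Tensor, k_cache: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for attention scores against a growing KV cache.
    return q @ k_cache.transpose(-1, -2)


q = torch.randn(1, 8, 1, 64)      # (batch, heads, 1 query token, head_dim)
for seq_len in (16, 17, 18):      # cache grows by one position per decode step
    k_cache = torch.randn(1, 8, seq_len, 64)
    # Mark dim 2 (sequence length) as dynamic so the compiled graph is reused
    # across cache lengths instead of being retraced for each new shape.
    torch._dynamo.mark_dynamic(k_cache, 2)
    print(attend(q, k_cache).shape)
```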
Review comment: personally, I like the idea of a helper function here to make things more readable.
Reply: +1, specifically avoiding double-nested for loops is nice to do
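As a rough sketch of the helper-function idea discussed above, the dummy-request construction from the diff could be factored out along these lines. The name `build_warmup_scheduler_output` and the exact factoring are assumptions; the fields simply mirror the diff in this PR and may drift from whatever is eventually merged.

```python
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import NewRequestData, SchedulerOutput


def build_warmup_scheduler_output(prompt_len: int, num_decode_tokens: int,
                                  batch_size: int) -> SchedulerOutput:
    """Build a dummy SchedulerOutput covering one warmup shape."""
    dummy_requests = [
        NewRequestData(
            req_id=f"warmup-{i}",
            prompt_token_ids=[1] * prompt_len,
            prompt="test",
            mm_inputs=[],
            mm_hashes=[],
            mm_positions=[],
            sampling_params=SamplingParams(max_tokens=num_decode_tokens),
            block_ids=[0],
            num_computed_tokens=0,
            lora_request=None,
        ) for i in range(batch_size)
    ]
    num_scheduled_tokens = {i: prompt_len for i in range(batch_size)}
    return SchedulerOutput(
        scheduled_new_reqs=dummy_requests,
        scheduled_cached_reqs=[],
        num_scheduled_tokens=num_scheduled_tokens,
        total_num_scheduled_tokens=sum(num_scheduled_tokens.values()),
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
    )


# The warmup loop in compile_or_warm_up_model would then reduce to:
#     scheduler_output = build_warmup_scheduler_output(
#         prompt_len, num_decode_tokens, batch_size)
#     self.execute_model(scheduler_output)
```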