Use execute_model for warmup #26
Conversation
👋 Hi! Thank you for contributing to vLLM support on Spyre. Now you are good to go 🚀
```python
print(f"[SpyreWorker] Warming up for prompt length {prompt_len}, "
      f"decoding {num_decode_tokens} tokens with batch "
      f"size {batch_size}")
self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens,
```
Personally, I like the idea of a helper function here to make things more readable.
+1, specifically avoiding doubly nested for-loops is nice to do.
vllm_spyre/v1/worker/spyre_worker.py (Outdated)
```python
dummy_requests.append(
    NewRequestData(
        req_id=f"warmup-{i}",
        prompt_token_ids=[1] * prompt_len,
```
We previously sampled these tokens from a list of valid tokens:

valid_token_ids = [i for i in range(1, vocab_size) if i not in set(special_token_ids)]

where special_token_ids contains the BOS, EOS, and pad token IDs. I'm not sure whether this was needed, or what happens if, in your case, any of those special token IDs is 1.
Hah, for the very first dummy model I checked, this is true: https://huggingface.co/JackFram/llama-160m/blob/main/config.json#L6

I think vLLM uses a bunch of repeated token id 0 for profiling, since the input ids tensor is just initialized with torch.zeros, and for text-only models it's not updated during profiling.

The general idea with setting a repeated token ID was to have the model continue the sequence, so it doesn't hit an EOS token early and stop. But if we control the loop here that keeps invoking the model, maybe that doesn't matter.
Yeah, I also suspect it does not matter since we force the decode steps in the loop, but better safe than sorry :)
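
For reference, a minimal sketch of the sampling approach described above, assuming a Hugging Face tokenizer; the tokenizer loading and the prompt_len value are illustrative stand-ins, not the worker's actual setup:

```python
import random

from transformers import AutoTokenizer

# Hypothetical stand-ins: the real worker would derive the vocab and
# special token ids from its own model config rather than load a tokenizer.
tokenizer = AutoTokenizer.from_pretrained("JackFram/llama-160m")
prompt_len = 64

# Exclude special tokens (BOS, EOS, pad, ...) so the dummy prompt cannot
# accidentally contain a token that terminates the sequence early.
special_token_ids = set(tokenizer.all_special_ids)
valid_token_ids = [
    i for i in range(1, tokenizer.vocab_size) if i not in special_token_ids
]

# Sample random valid ids rather than repeating a hard-coded id like 1,
# which may collide with a special token (e.g. BOS for this model).
prompt_token_ids = random.choices(valid_token_ids, k=prompt_len)
```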
vllm_spyre/v1/worker/spyre_worker.py (Outdated)
```python
)

# Use execute_model for warm up
self.execute_model(scheduler_output)
```
As far as I understand, this executes the model for only one step. In your case it does the prefill and generates the first token. We need to warm up not only for prefill, but also for (num_decode_tokens - 1) decode steps (since prefill already produced one token).
vllm_spyre/v1/worker/spyre_worker.py (Outdated)
```python
if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
    from torch_sendnn import torch_sendnn
    ul_start_time = time.time()
    torch_sendnn.update_lazyhandle()
    ul_stop_time = time.time()
    ul_total_t = ul_stop_time - ul_start_time
    print(f"update_lazyhandle() done (duration: {ul_total_t}s)")
```
After torch_sendnn.update_lazyhandle() there is a second complete warmup needed. To sum up, the full sequence is:

- complete forward pass: prefill plus (num_decode_tokens - 1) decode steps
- torch_sendnn.update_lazyhandle()
- complete forward pass: prefill plus (num_decode_tokens - 1) decode steps

See comment below.
Thanks for clarifying this; I'm learning as I go, but it makes sense to me now. I took another stab at it, still trying to use execute_model to avoid doing anything manually except the dummy data setup.
```python
# 1. trace
print("[SpyreWorker] warmup 1/2...")
# TODO: torch_sendnn.CleanGraph() should be necessary?
# warmup 1st forward pass
self._warmup_model_forward_pass(warmup_tokens_tensor,
                                valid_token_ids_tensor, prompt_len,
                                num_decode_tokens, batch_size,
                                extra_kwargs)

# 2. compile
print("[SpyreWorker] warmup 2/2...")
if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
    from torch_sendnn import torch_sendnn
    ul_start_time = time.time()
    torch_sendnn.update_lazyhandle()
    ul_stop_time = time.time()
    ul_total_t = ul_stop_time - ul_start_time
    print(f"update_lazyhandle() done (duration: {ul_total_t}s)")

# warmup 2nd forward pass
self._warmup_model_forward_pass(warmup_tokens_tensor,
                                valid_token_ids_tensor, prompt_len,
                                num_decode_tokens, batch_size,
                                extra_kwargs)
```
This has to happen for each warmup shape (each combination of prompt_len, num_decode_tokens, and batch_size).
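
A minimal sketch of that per-shape loop, reusing the _warmup_spyre_fixed_size call shown earlier; the spyre_warmup_shapes attribute is an assumption for illustration, not the merged code:

```python
# Hypothetical sketch: each (prompt_len, num_decode_tokens, batch_size)
# combination is compiled separately, so the full two-phase warmup
# (trace, update_lazyhandle, second pass) must run once per shape.
for prompt_len, num_decode_tokens, batch_size in self.spyre_warmup_shapes:
    self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens, batch_size)
```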
Unfortunately, I have not yet been able to try your changes because of problems in my dev environment. But besides the feedback from the other reviewers, I feel that the final code should have more comments, especially covering the review points raised by @yannicks1, for better understanding later.
LGTM! I'll let @yannicks1 take another look through since he has a better understanding of the requirements.
Other than the two minor comments this looks good to me! Thanks for contributing!

PS: I am mildly confused why the Spyre tests did not succeed. tests/test_spyre_embeddings.py fails with V0, but as your code changes only touch V1, the cause has got to be somewhere else...
vllm_spyre/v1/worker/spyre_worker.py (Outdated)
```python
logger.info("Warmup 1/2: Prefill...")
self.execute_model(scheduler_output)  # Prefill step

# Switch to cached requests to trigger decoding steps
scheduler_output.scheduled_new_reqs = []
scheduler_output.scheduled_cached_reqs = cached_requests

logger.info("Warmup 1/2: Decoding...")
for _ in range(num_decode_tokens - 1):
    self.execute_model(scheduler_output)
```
Personally, I am for the use of helper functions wherever they help reduce code duplication. I am aware it's just 10 lines here, but they could be eliminated by reusing the _warmup_model_forward_pass we introduced in the original implementation.
Sounds good, I re-introduced _warmup_model_forward_pass to handle the duplicated pass code.
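
For reference, a rough sketch of what such a helper might look like, based on the prefill/decode snippet above; the exact signature in the PR may differ, and cached_requests is assumed to be prepared by the caller:

```python
def _warmup_model_forward_pass(self, scheduler_output, cached_requests,
                               num_decode_tokens):
    """Run one complete warmup pass: a prefill step followed by the
    remaining decode steps."""
    # Prefill: schedules the new requests and produces the first token.
    self.execute_model(scheduler_output)

    # Switch to cached requests so subsequent calls run decode steps.
    scheduler_output.scheduled_new_reqs = []
    scheduler_output.scheduled_cached_reqs = cached_requests

    # Prefill already produced one token, so num_decode_tokens - 1 remain.
    for _ in range(num_decode_tokens - 1):
        self.execute_model(scheduler_output)
```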
* 🐛 fix batch handling in V1 runner
* ⚗️ try v1 test only
* ⚗️ add a bit more prompt
* ⚗️ unclear why CI won't count to 0
* ♻️ rename map_output_indices

Signed-off-by: Joe Runde <[email protected]>
Force-pushed from af74bfd to 45cec8f
LGTM! Thanks for contributing.
PR to use execute_model in the model warmup (v1.worker.SpyreWorker.compile_or_warm_up_model) instead of a separate dummy forward pass.

Closes #12