21 changes: 0 additions & 21 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -442,24 +442,3 @@ def pad_input_ids(
position_ids = torch.stack(position_ids_list)

return input_ids, position_ids, mask

def _raw_model_forward(
self,
input_ids: torch.Tensor,
mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value_states: Optional[List[Tuple[torch.Tensor,
torch.Tensor]]] = None,
use_cache: bool = False,
only_last_token: bool = False,
attn_algorithm: Optional[str] = None
) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor,
torch.Tensor]]]]:

return self.model.model(input_ids,
mask=mask,
position_ids=position_ids,
past_key_value_states=past_key_value_states,
use_cache=use_cache,
only_last_token=only_last_token,
attn_algorithm=attn_algorithm)
162 changes: 47 additions & 115 deletions vllm_spyre/v1/worker/spyre_worker.py
@@ -13,7 +13,8 @@
init_distributed_environment)
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import NewRequestData, SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
@@ -68,8 +69,51 @@ def compile_or_warm_up_model(self) -> None:
print(f"[SpyreWorker] Warming up for prompt length {prompt_len}, "
f"decoding {num_decode_tokens} tokens with batch "
f"size {batch_size}")
self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens,
Collaborator: Personally, I like the idea of a helper function here to make things more readable.

Collaborator: +1, specifically avoiding double-nested for loops is nice to do.

self.restricted_tokens, batch_size)

num_scheduled_tokens: dict = {}
total_num_scheduled_tokens: int = 0
dummy_requests: list = []
for i in range(batch_size):
dummy_requests.append(
NewRequestData(
req_id=f"warmup-{i}",
prompt_token_ids=[1] * prompt_len,
Collaborator: We previously sampled these tokens from a list of valid tokens:

    valid_token_ids = [i for i in range(1, vocab_size) if i not in set(special_token_ids)]

where special_token_ids contains the BOS, EOS and pad token ids. Not sure whether this was needed, or what happens if in your case any of the above special token ids is 1.

Collaborator (@joerunde, Mar 17, 2025): Hah, for the very first dummy model I checked this is true: https://huggingface.co/JackFram/llama-160m/blob/main/config.json#L6

I think vllm uses a bunch of repeated token id 0 for profiling, since the input ids tensor is just initialized with torch.zeros and, for text-only models, it's not updated for profiling.

The general idea with setting a repeated token ID was to have the model continue the sequence, so it doesn't hit an eos token early and stop. But if we control the loop here that keeps invoking the model, maybe that doesn't matter.

Collaborator: Yeah, I also suspect it does not matter since we force the decode steps in the loop, but better safe than sorry :)

prompt="test",
mm_inputs=[],
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(
max_tokens=num_decode_tokens),
block_ids=[0],
num_computed_tokens=0,
lora_request=None,
))
num_scheduled_tokens[i] = prompt_len
total_num_scheduled_tokens += num_scheduled_tokens[i]

scheduler_output = SchedulerOutput(
scheduled_new_reqs=dummy_requests,
scheduled_cached_reqs=[],
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
)

# Use execute_model for warm up
self.execute_model(scheduler_output)
Collaborator: As far as I understand, this executes the model only for one step. In your case it does the prefill and generates the 1st token. We need to warm up not only for prefill, but also for (num_decode_tokens - 1) decode steps (since the prefill already produced one token).
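
One possible way to cover the remaining steps, sketched against the code above and not taken from the PR itself: capture the output of the prefill call to execute_model, then keep calling execute_model with the same requests re-scheduled as cached requests, feeding back the previously sampled token. CachedRequestData (assumed importable from vllm.v1.core.scheduler next to NewRequestData) and the ModelRunnerOutput fields used here are assumptions that may differ between vLLM versions.

    # Sketch only: drive (num_decode_tokens - 1) decode steps through execute_model.
    # Assumes sampled_token_ids is a list of per-request token-id lists ordered
    # like dummy_requests; both the class and the field are version-dependent.
    from vllm.v1.core.scheduler import CachedRequestData

    model_output = self.execute_model(scheduler_output)  # prefill + 1st token

    for step in range(num_decode_tokens - 1):
        cached_reqs = [
            CachedRequestData(
                req_id=req.req_id,
                resumed_from_preemption=False,
                # feed back the token sampled for this request in the last step
                new_token_ids=[model_output.sampled_token_ids[idx][-1]],
                new_block_ids=[],
                num_computed_tokens=prompt_len + step,
            ) for idx, req in enumerate(dummy_requests)
        ]
        decode_step = SchedulerOutput(
            scheduled_new_reqs=[],
            scheduled_cached_reqs=cached_reqs,
            # exactly one new token per request on a decode step
            num_scheduled_tokens={req.req_id: 1 for req in dummy_requests},
            total_num_scheduled_tokens=len(dummy_requests),
            scheduled_spec_decode_tokens={},
            scheduled_encoder_inputs={},
            num_common_prefix_blocks=0,
            finished_req_ids=set(),
            free_encoder_input_ids=[],
        )
        model_output = self.execute_model(decode_step)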


if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
from torch_sendnn import torch_sendnn
ul_start_time = time.time()
torch_sendnn.update_lazyhandle()
ul_stop_time = time.time()
ul_total_t = ul_stop_time - ul_start_time
print(f"update_lazyhandle() done (duration: {ul_total_t}s)")

Collaborator: After torch_sendnn.update_lazyhandle() there is a second complete warmup needed.

To sum up:

  1. complete forward pass: prefill plus (num_decode_tokens - 1) decode steps
  2. torch_sendnn.update_lazyhandle()
  3. complete forward pass: prefill plus (num_decode_tokens - 1) decode steps

See comment below.

Collaborator (Author): Thanks for clarifying this; I'm learning as I go, but it makes sense to me now.

I took another stab at it, still trying to use execute_model to avoid doing anything manually except the dummy data setup.
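
A rough outline of that flow, offered only as a sketch: loop over the configured warmup shapes and, for each one, run a complete pass, call update_lazyhandle() on the sendnn_decoder backend, then run a second complete pass. Here warmup_shapes and _warmup_one_shape are hypothetical placeholders for the shape list and for a helper that performs the prefill plus (num_decode_tokens - 1) decode steps via execute_model.

    for prompt_len, num_decode_tokens, batch_size in warmup_shapes:
        # 1. first complete pass: prefill plus (num_decode_tokens - 1) decode steps
        self._warmup_one_shape(prompt_len, num_decode_tokens, batch_size)

        # 2. finalize the compiled graphs on the sendnn_decoder backend
        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
            from torch_sendnn import torch_sendnn
            torch_sendnn.update_lazyhandle()

        # 3. second complete pass, now exercising the compiled graphs
        self._warmup_one_shape(prompt_len, num_decode_tokens, batch_size)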

all_warmup_end_t = time.time()
all_warmup_total_t = all_warmup_end_t - all_warmup_start_t
print(f"[SpyreWorker] All warmups for "
@@ -209,118 +253,6 @@ def load_model(self):
load_model_total_t = load_model_end_t - load_model_start_t
print(f"\tload model took {load_model_total_t}s")

def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
special_token_ids, batch_size):
# TODO See if we can use `self.execute_model` instead for the warmup
# It's slightly risky to implement different forward pass logic here,
# which can go out of sync with the real forward pass and cause problems
# for torch.compile

# warmup the model
warmup_start_t = time.time()
# NOTE(ngl): empty tensor causes spyre to hang, so using
# randint without 0 and the eos and bos token

# Create a list of valid values between 1 (inclusive) and vocab
# size (exclusive) by excluding the eos and bos token ids
# (in special_token_ids)
vocab_size = self.model_runner.vocab_size
valid_token_ids = [
i for i in range(1, vocab_size) if i not in set(special_token_ids)
]
# Convert to tensor for sampling
valid_token_ids_tensor = torch.tensor(valid_token_ids,
dtype=torch.long,
device=torch.device("cpu"))
# Sample from the valid token ids
warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (batch_size, prompt_len))]

extra_kwargs = {}
if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND not in [
"sendnn", "sendnn_decoder"
]:
# Bug in 2.3.1 fixed in 2.4.1 for SDPA flash cpu
# impl when padding too much
extra_kwargs["attn_algorithm"] = "math"

print(f"[SpyreWorker] warmup for prompt length "
f"{prompt_len} and max output tokens {num_decode_tokens}.")

# 1. trace
print("[SpyreWorker] warmup 1/2...")
# TODO: torch_sendnn.CleanGraph() should be necessary?
# warmup 1st forward pass
self._warmup_model_forward_pass(warmup_tokens_tensor,
valid_token_ids_tensor, prompt_len,
num_decode_tokens, batch_size,
extra_kwargs)

# 2. compile
print("[SpyreWorker] warmup 2/2...")
if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
from torch_sendnn import torch_sendnn
ul_start_time = time.time()
torch_sendnn.update_lazyhandle()
ul_stop_time = time.time()
ul_total_t = ul_stop_time - ul_start_time
print(f"update_lazyhandle() done (duration: {ul_total_t}s)")

# warmup 2nd forward pass
self._warmup_model_forward_pass(warmup_tokens_tensor,
valid_token_ids_tensor, prompt_len,
num_decode_tokens, batch_size,
extra_kwargs)
Collaborator: This has to happen for each warmup shape (combination of prompt_len, num_decode_tokens, batch_size).


warmup_end_t = time.time()
warmup_total_t = warmup_end_t - warmup_start_t
print("[SpyreWorker] ... warmup finished.")
print(f"\twarmup took {warmup_total_t}s (for prompt length"
f"{prompt_len} and max output tokens {num_decode_tokens})")

def _warmup_model_forward_pass(self, warmup_tokens_tensor,
valid_token_ids_tensor, prompt_len,
num_decode_tokens, batch_size,
extra_kwargs):
# padding warmup tokens to obtain the
# corresponding position ids and mask
warmup_tokens_pad, self.model_runner._position_ids, \
self.model_runner._mask = self.model_runner.pad_input_ids(
warmup_tokens_tensor, min_pad_length=prompt_len)

logits, past_key_value_states = self.model_runner._raw_model_forward(
warmup_tokens_pad,
position_ids=self.model_runner._position_ids,
mask=self.model_runner._mask,
past_key_value_states=None,
use_cache=True,
only_last_token=True,
**extra_kwargs)
# decoding
for i in range(num_decode_tokens - 1):
# sampling next input token from vocab without bos and eos tokens
decode_tokens = valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (batch_size, 1))]

# update mask and position_ids
self.model_runner._update_mask()
self.model_runner._update_position_ids()

if past_key_value_states is not None:
for layer in past_key_value_states:
for tensor in layer:
torch._dynamo.mark_dynamic(tensor, 2)

logits, past_key_value_states = self.model_runner.\
_raw_model_forward(
decode_tokens,
position_ids=self.model_runner._position_ids,
mask=self.model_runner._mask,
past_key_value_states=past_key_value_states,
use_cache=True,
only_last_token=True,
**extra_kwargs)

def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
