21 changes: 0 additions & 21 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -437,27 +437,6 @@ def pad_input_ids(

return input_ids, position_ids, mask

def _raw_model_forward(
self,
input_ids: torch.Tensor,
mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value_states: Optional[List[Tuple[torch.Tensor,
torch.Tensor]]] = None,
use_cache: bool = False,
only_last_token: bool = False,
attn_algorithm: Optional[str] = None
) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor,
torch.Tensor]]]]:

return self.model.model(input_ids,
mask=mask,
position_ids=position_ids,
past_key_value_states=past_key_value_states,
use_cache=use_cache,
only_last_token=only_last_token,
attn_algorithm=attn_algorithm)

def get_kv_cache_spec(self) -> KVCacheSpec:
"""
This method should generate the KVCache spec by parsing the kv cache
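The removed _raw_model_forward wrapper is no longer needed because the warmup now drives the worker's regular execute_model path with a synthetic SchedulerOutput, as the spyre_worker.py diff below shows. A condensed, self-contained sketch of that pattern (field values are copied from the diff below; the helper name build_warmup_prefill is ours, and the exact constructor signatures may differ across vLLM versions):

```python
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import NewRequestData, SchedulerOutput


def build_warmup_prefill(prompt_token_ids: list[int],
                         max_tokens: int) -> SchedulerOutput:
    # A single synthetic prefill request, mirroring the warmup requests
    # constructed in spyre_worker.py below.
    req = NewRequestData(
        req_id="warmup",
        prompt_token_ids=prompt_token_ids,
        prompt="test",
        mm_inputs=[],
        mm_hashes=[],
        mm_positions=[],
        sampling_params=SamplingParams(max_tokens=max_tokens),
        block_ids=[0],
        num_computed_tokens=0,
        lora_request=None,
    )
    return SchedulerOutput(
        scheduled_new_reqs=[req],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={0: len(prompt_token_ids)},
        total_num_scheduled_tokens=len(prompt_token_ids),
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
    )


# Usage sketch: worker.execute_model(build_warmup_prefill(token_ids, 3))
# runs the prefill step through the same code path used at serving time.
```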
165 changes: 83 additions & 82 deletions vllm_spyre/v1/worker/spyre_worker.py
@@ -14,7 +14,9 @@
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerBase as WorkerBaseV1
@@ -233,12 +235,7 @@ def load_model(self):

def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
special_token_ids, batch_size):
# TODO See if we can use `self.execute_model` instead for the warmup
# It's slightly risky to implement different forward pass logic here,
# which can go out of sync with the real forward pass and cause problems
# for torch.compile

# warmup the model
warmup_start_t = time.time()
# NOTE(ngl): empty tensor causes spyre to hang, so using
# randint without 0 and the eos and bos token
@@ -253,95 +250,99 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
# Convert to tensor for sampling
valid_token_ids_tensor = torch.tensor(valid_token_ids,
dtype=torch.long,
device=torch.device("cpu"))
device="cpu")

# Sample from the valid token ids
warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (batch_size, prompt_len))]

extra_kwargs = {}
if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND not in [
"sendnn", "sendnn_decoder"
]:
# Bug in 2.3.1 fixed in 2.4.1 for SDPA flash cpu
# impl when padding too much
extra_kwargs["attn_algorithm"] = "math"

print(f"[SpyreWorker] warmup for prompt length "
f"{prompt_len} and max output tokens {num_decode_tokens}.")

# 1. trace
print("[SpyreWorker] warmup 1/2...")
# TODO: torch_sendnn.CleanGraph() should be necessary?
# warmup 1st forward pass
self._warmup_model_forward_pass(warmup_tokens_tensor,
valid_token_ids_tensor, prompt_len,
num_decode_tokens, batch_size,
extra_kwargs)

# 2. compile
print("[SpyreWorker] warmup 2/2...")
# Create requests to be used for prefill steps
dummy_requests = [
NewRequestData(
req_id="warmup",
prompt_token_ids=warmup_tokens_tensor[i].tolist(),
prompt="test",
mm_inputs=[],
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(max_tokens=num_decode_tokens),
block_ids=[0],
num_computed_tokens=0,
lora_request=None,
) for i in range(batch_size)
]

# Set up dummy cached_requests to be used for decode steps
cached_requests = [
CachedRequestData(
req_id=req.req_id,
resumed_from_preemption=False,
new_token_ids=[
valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (1, )).item()]
], # placeholder token
new_block_ids=req.block_ids,
num_computed_tokens=req.num_computed_tokens,
) for req in dummy_requests
]

# To be used for execute_model, start with scheduled_new_reqs
# for prefill
scheduler_output = SchedulerOutput(
scheduled_new_reqs=dummy_requests,
scheduled_cached_reqs=[],
num_scheduled_tokens={i: prompt_len
for i in range(batch_size)},
total_num_scheduled_tokens=sum(prompt_len
for _ in range(batch_size)),
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
)

# First full forward pass
logger.info("Warmup 1/2: Prefill...")
self.execute_model(scheduler_output) # Prefill step

# Switch to cached requests to trigger decoding steps
scheduler_output.scheduled_new_reqs = []
scheduler_output.scheduled_cached_reqs = cached_requests

logger.info("Warmup 1/2: Decoding...")
for _ in range(num_decode_tokens - 1):
self.execute_model(scheduler_output)

Collaborator:

Personally, I am in favor of using helper functions wherever they help reduce code duplication. I am aware it's just 10 lines here, but they could be eliminated by reusing the _warmup_model_forward_pass we introduced in the original implementation.

@rafvasq (Collaborator, Author), Mar 20, 2025:

Sounds good. I re-introduced _warmup_model_forward_pass to handle the duplicated pass code.
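
The re-introduced helper itself is not visible in the hunks shown here. A minimal sketch of what it might look like, assuming it simply wraps the prefill-plus-decode loop that appears twice in this warmup (the name follows the comment above; the signature and body are assumptions, not the merged code):

```python
def _warmup_model_forward_pass(self, scheduler_output, dummy_requests,
                               cached_requests, num_decode_tokens):
    # Sketch only (assumed signature and body): one prefill step through
    # execute_model, then the remaining decode steps, mirroring the
    # duplicated warmup code in this diff.
    scheduler_output.scheduled_new_reqs = dummy_requests
    scheduler_output.scheduled_cached_reqs = []
    self.execute_model(scheduler_output)  # prefill

    # Switch to cached requests to trigger the decode path.
    scheduler_output.scheduled_new_reqs = []
    scheduler_output.scheduled_cached_reqs = cached_requests
    for _ in range(num_decode_tokens - 1):
        self.execute_model(scheduler_output)  # decode
```

With a helper along these lines, each of the two warmup passes in this function reduces to a single call.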

# update_lazyhandle
if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
from torch_sendnn import torch_sendnn
ul_start_time = time.time()
torch_sendnn.update_lazyhandle()
ul_stop_time = time.time()
ul_total_t = ul_stop_time - ul_start_time
print(f"update_lazyhandle() done (duration: {ul_total_t}s)")
logger.info("update_lazyhandle() done (duration: %.3fs",
ul_stop_time - ul_start_time)

# Second full forward pass
logger.info("Warmup 2/2: Prefill step...")
scheduler_output.scheduled_new_reqs = dummy_requests
scheduler_output.scheduled_cached_reqs = []
self.execute_model(scheduler_output)

# warmup 2nd forward pass
self._warmup_model_forward_pass(warmup_tokens_tensor,
valid_token_ids_tensor, prompt_len,
num_decode_tokens, batch_size,
extra_kwargs)
# Switch to cached requests to trigger decoding steps
scheduler_output.scheduled_new_reqs = []
scheduler_output.scheduled_cached_reqs = cached_requests

logger.info("[Warmup 2/2: Decoding steps...")
for _ in range(num_decode_tokens - 1):
self.execute_model(scheduler_output)

warmup_end_t = time.time()
warmup_total_t = warmup_end_t - warmup_start_t
print("[SpyreWorker] ... warmup finished.")
print(f"\twarmup took {warmup_total_t}s (for prompt length"
f"{prompt_len} and max output tokens {num_decode_tokens})")

def _warmup_model_forward_pass(self, warmup_tokens_tensor,
valid_token_ids_tensor, prompt_len,
num_decode_tokens, batch_size,
extra_kwargs):
# padding warmup tokens to obtain the
# corresponding position ids and mask
warmup_tokens_pad, self.model_runner._position_ids, \
self.model_runner._mask = self.model_runner.pad_input_ids(
warmup_tokens_tensor, min_pad_length=prompt_len)

logits, past_key_value_states = self.model_runner._raw_model_forward(
warmup_tokens_pad,
position_ids=self.model_runner._position_ids,
mask=self.model_runner._mask,
past_key_value_states=None,
use_cache=True,
only_last_token=True,
**extra_kwargs)
# decoding
for i in range(num_decode_tokens - 1):
# sampling next input token from vocab without bos and eos tokens
decode_tokens = valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (batch_size, 1))]

# update mask and position_ids
self.model_runner._update_mask()
self.model_runner._update_position_ids()

if past_key_value_states is not None:
for layer in past_key_value_states:
for tensor in layer:
torch._dynamo.mark_dynamic(tensor, 2)

logits, past_key_value_states = self.model_runner.\
_raw_model_forward(
decode_tokens,
position_ids=self.model_runner._position_ids,
mask=self.model_runner._mask,
past_key_value_states=past_key_value_states,
use_cache=True,
only_last_token=True,
**extra_kwargs)
logger.info("Warmup finished.")
logger.info(
"Warmup took %.3fs (for prompt length %d and max output tokens %d)",
warmup_total_t, prompt_len, num_decode_tokens)

@property
def do_metadata_broadcast(self) -> bool: