21 changes: 0 additions & 21 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -459,27 +459,6 @@ def pad_input_ids(

return input_ids, position_ids, mask

def _raw_model_forward(
self,
input_ids: torch.Tensor,
mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value_states: Optional[List[Tuple[torch.Tensor,
torch.Tensor]]] = None,
use_cache: bool = False,
only_last_token: bool = False,
attn_algorithm: Optional[str] = None
) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor,
torch.Tensor]]]]:

return self.model.model(input_ids,
mask=mask,
position_ids=position_ids,
past_key_value_states=past_key_value_states,
use_cache=use_cache,
only_last_token=only_last_token,
attn_algorithm=attn_algorithm)

def get_kv_cache_spec(self) -> KVCacheSpec:
"""
This method should generate the KVCache spec by parsing the kv cache
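Editor's note on the removal above: `_raw_model_forward` was a thin pass-through to the wrapped FMS model. If any out-of-tree code still relied on it, the equivalent direct call is shown in the minimal sketch below; it assumes the runner keeps exposing the `self.model.model` callable used in the deleted lines, and the standalone `raw_forward` function and `runner` parameter are illustrative only, not part of the PR.

# Sketch only (not part of the PR): the direct call that the removed
# `_raw_model_forward` helper wrapped.  `runner` stands in for the
# SpyreModelRunner instance.
from typing import List, Optional, Tuple

import torch


def raw_forward(
    runner,
    input_ids: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value_states: Optional[List[Tuple[torch.Tensor,
                                               torch.Tensor]]] = None,
    use_cache: bool = False,
    only_last_token: bool = False,
    attn_algorithm: Optional[str] = None,
) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
    # Identical keyword arguments to the deleted wrapper.
    return runner.model.model(input_ids,
                              mask=mask,
                              position_ids=position_ids,
                              past_key_value_states=past_key_value_states,
                              use_cache=use_cache,
                              only_last_token=only_last_token,
                              attn_algorithm=attn_algorithm)
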
151 changes: 80 additions & 71 deletions vllm_spyre/v1/worker/spyre_worker.py
@@ -14,7 +14,9 @@
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerBase as WorkerBaseV1
@@ -233,12 +235,7 @@ def load_model(self):

def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
special_token_ids, batch_size):
# TODO See if we can use `self.execute_model` instead for the warmup
# It's slightly risky to implement different forward pass logic here,
# which can go out of sync with the real forward pass and cause problems
# for torch.compile

# warmup the model
warmup_start_t = time.time()
# NOTE(ngl): empty tensor causes spyre to hang, so using
# randint without 0 and the eos and bos token
@@ -254,6 +251,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
valid_token_ids_tensor = torch.tensor(valid_token_ids,
dtype=torch.long,
device=torch.device("cpu"))

# Sample from the valid token ids
warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (batch_size, prompt_len))]
@@ -266,82 +264,93 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
# impl when padding too much
extra_kwargs["attn_algorithm"] = "math"

print(f"[SpyreWorker] warmup for prompt length "
f"{prompt_len} and max output tokens {num_decode_tokens}.")
# Set up dummy requests for prefill steps
dummy_requests = [
NewRequestData(
req_id="warmup",
prompt_token_ids=warmup_tokens_tensor[i].tolist(),
prompt="test",
mm_inputs=[],
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(max_tokens=num_decode_tokens),
block_ids=[0],
num_computed_tokens=0,
lora_request=None,
) for i in range(batch_size)
]

# Set up dummy cached_requests for decode steps
cached_requests = [
CachedRequestData(
req_id=req.req_id,
resumed_from_preemption=False,
new_token_ids=[
valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (1, )).item()]
], # placeholder token
new_block_ids=req.block_ids,
num_computed_tokens=req.num_computed_tokens,
) for req in dummy_requests
]

# 1. trace
print("[SpyreWorker] warmup 1/2...")
# TODO: torch_sendnn.CleanGraph() should be necessary?
# warmup 1st forward pass
self._warmup_model_forward_pass(warmup_tokens_tensor,
valid_token_ids_tensor, prompt_len,
num_decode_tokens, batch_size,
extra_kwargs)
# Set up scheduler_output for execute_model
scheduler_output = SchedulerOutput(
scheduled_new_reqs=dummy_requests,
scheduled_cached_reqs=[],
num_scheduled_tokens={i: prompt_len
for i in range(batch_size)},
total_num_scheduled_tokens=sum(prompt_len
for _ in range(batch_size)),
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
)

# First full forward pass
logger.info("Warmup forward pass 1/2...")
self._warmup_model_forward_pass(scheduler_output, dummy_requests,
cached_requests, num_decode_tokens)

# 2. compile
print("[SpyreWorker] warmup 2/2...")
if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
from torch_sendnn import torch_sendnn
ul_start_time = time.time()
torch_sendnn.update_lazyhandle()
ul_stop_time = time.time()
ul_total_t = ul_stop_time - ul_start_time
print(f"update_lazyhandle() done (duration: {ul_total_t}s)")
logger.info("update_lazyhandle() done (duration: %.3fs)",
ul_stop_time - ul_start_time)

# warmup 2nd forward pass
self._warmup_model_forward_pass(warmup_tokens_tensor,
valid_token_ids_tensor, prompt_len,
num_decode_tokens, batch_size,
extra_kwargs)
# Second full forward pass
logger.info("Warmup forward pass 2/2...")
self._warmup_model_forward_pass(scheduler_output, dummy_requests,
cached_requests, num_decode_tokens)

warmup_end_t = time.time()
warmup_total_t = warmup_end_t - warmup_start_t
print("[SpyreWorker] ... warmup finished.")
print(f"\twarmup took {warmup_total_t}s (for prompt length"
f"{prompt_len} and max output tokens {num_decode_tokens})")

def _warmup_model_forward_pass(self, warmup_tokens_tensor,
valid_token_ids_tensor, prompt_len,
num_decode_tokens, batch_size,
extra_kwargs):
# padding warmup tokens to obtain the
# corresponding position ids and mask
warmup_tokens_pad, self.model_runner._position_ids, \
self.model_runner._mask = self.model_runner.pad_input_ids(
warmup_tokens_tensor, min_pad_length=prompt_len)

logits, past_key_value_states = self.model_runner._raw_model_forward(
warmup_tokens_pad,
position_ids=self.model_runner._position_ids,
mask=self.model_runner._mask,
past_key_value_states=None,
use_cache=True,
only_last_token=True,
**extra_kwargs)
# decoding
for i in range(num_decode_tokens - 1):
# sampling next input token from vocab without bos and eos tokens
decode_tokens = valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (batch_size, 1))]

# update mask and position_ids
self.model_runner._update_mask()
self.model_runner._update_position_ids()

if past_key_value_states is not None:
for layer in past_key_value_states:
for tensor in layer:
torch._dynamo.mark_dynamic(tensor, 2)

logits, past_key_value_states = self.model_runner.\
_raw_model_forward(
decode_tokens,
position_ids=self.model_runner._position_ids,
mask=self.model_runner._mask,
past_key_value_states=past_key_value_states,
use_cache=True,
only_last_token=True,
**extra_kwargs)
logger.info("Warmup finished.")
logger.info(
"Warmup took %.3fs (for prompt length %d and max output tokens %d)",
warmup_total_t, prompt_len, num_decode_tokens)

def _warmup_model_forward_pass(
self,
scheduler_output: SchedulerOutput,
requests: List[NewRequestData],
cached_requests: List[CachedRequestData],
num_decode_tokens,
):
"""Handle a complete forward pass"""
scheduler_output.scheduled_new_reqs = requests
scheduler_output.scheduled_cached_reqs = []
self.execute_model(scheduler_output) # Prefill

# Switch to cached requests to trigger decoding steps
scheduler_output.scheduled_new_reqs = []
scheduler_output.scheduled_cached_reqs = cached_requests
for _ in range(num_decode_tokens - 1):
self.execute_model(scheduler_output)

@property
def do_metadata_broadcast(self) -> bool:
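For reviewers skimming the diff, here is a condensed, self-contained sketch of the new warmup flow: build a throwaway request, schedule it through a hand-rolled `SchedulerOutput`, run one prefill step, then flip the same output to cached requests for the decode steps. All dataclass fields are taken from the diff; the `worker` object, token values, and function name are placeholders, not part of the PR.

# Sketch only (not part of the PR): the new warmup path in one place,
# for a single dummy request.  Assumes `worker` is a SpyreWorker exposing
# `execute_model`, as in spyre_worker.py above.
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import (CachedRequestData, NewRequestData,
                                    SchedulerOutput)


def warmup_via_execute_model(worker, prompt_token_ids, num_decode_tokens):
    # One throwaway request covering the full warmup prompt length.
    new_req = NewRequestData(
        req_id="warmup",
        prompt_token_ids=prompt_token_ids,
        prompt="test",
        mm_inputs=[],
        mm_hashes=[],
        mm_positions=[],
        sampling_params=SamplingParams(max_tokens=num_decode_tokens),
        block_ids=[0],
        num_computed_tokens=0,
        lora_request=None,
    )
    # The same request, replayed as a cached request to drive decode steps.
    cached_req = CachedRequestData(
        req_id=new_req.req_id,
        resumed_from_preemption=False,
        new_token_ids=[prompt_token_ids[0]],  # placeholder token
        new_block_ids=new_req.block_ids,
        num_computed_tokens=new_req.num_computed_tokens,
    )
    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[new_req],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={0: len(prompt_token_ids)},  # keyed as in the diff
        total_num_scheduled_tokens=len(prompt_token_ids),
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
    )
    worker.execute_model(scheduler_output)  # prefill
    # Switch to cached requests so subsequent calls exercise decoding.
    scheduler_output.scheduled_new_reqs = []
    scheduler_output.scheduled_cached_reqs = [cached_req]
    for _ in range(num_decode_tokens - 1):
        worker.execute_model(scheduler_output)

The design point visible in the diff: warmup now reuses `execute_model` end to end, so the graph that gets traced and compiled is the same one used at serve time, rather than a separate forward-pass path that could drift out of sync with it, which is exactly the concern raised in the removed TODO comment.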