
Commit 45cec8f

Refactor forward_pass, update comments/logs for clarity
Signed-off-by: Rafael Vasquez <[email protected]>
1 parent a6b8e25 commit 45cec8f

1 file changed: 36 additions, 28 deletions

vllm_spyre/v1/worker/spyre_worker.py

Lines changed: 36 additions & 28 deletions
@@ -250,13 +250,21 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         # Convert to tensor for sampling
         valid_token_ids_tensor = torch.tensor(valid_token_ids,
                                               dtype=torch.long,
-                                              device="cpu")
+                                              device=torch.device("cpu"))
 
         # Sample from the valid token ids
         warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
             0, len(valid_token_ids_tensor), (batch_size, prompt_len))]
 
-        # Create requests to be used for prefill steps
+        extra_kwargs = {}
+        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND not in [
+                "sendnn", "sendnn_decoder"
+        ]:
+            # Bug in 2.3.1 fixed in 2.4.1 for SDPA flash cpu
+            # impl when padding too much
+            extra_kwargs["attn_algorithm"] = "math"
+
+        # Set up dummy requests for prefill steps
         dummy_requests = [
             NewRequestData(
                 req_id="warmup",
@@ -272,7 +280,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
             ) for i in range(batch_size)
         ]
 
-        # Set up dummy cached_requests to be used for decode steps
+        # Set up dummy cached_requests for decode steps
         cached_requests = [
             CachedRequestData(
                 req_id=req.req_id,
@@ -286,8 +294,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
             ) for req in dummy_requests
         ]
 
-        # To be used for execute_model, start with scheduled_new_reqs
-        # for prefill
+        # Set up scheduler_output for execute_model
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=dummy_requests,
             scheduled_cached_reqs=[],
@@ -303,18 +310,10 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         )
 
         # First full forward pass
-        logger.info("Warmup 1/2: Prefill...")
-        self.execute_model(scheduler_output)  # Prefill step
-
-        # Switch to cached requests to trigger decoding steps
-        scheduler_output.scheduled_new_reqs = []
-        scheduler_output.scheduled_cached_reqs = cached_requests
+        logger.info("Warmup forward pass 1/2...")
+        self._warmup_model_forward_pass(scheduler_output, dummy_requests,
+                                        cached_requests, num_decode_tokens)
 
-        logger.info("Warmup 1/2: Decoding...")
-        for _ in range(num_decode_tokens - 1):
-            self.execute_model(scheduler_output)
-
-        # update_lazyhandle
         if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
             from torch_sendnn import torch_sendnn
             ul_start_time = time.time()
@@ -324,18 +323,9 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
                         ul_stop_time - ul_start_time)
 
         # Second full forward pass
-        logger.info("Warmup 2/2: Prefill step...")
-        scheduler_output.scheduled_new_reqs = dummy_requests
-        scheduler_output.scheduled_cached_reqs = []
-        self.execute_model(scheduler_output)
-
-        # Switch to cached requests to trigger decoding steps
-        scheduler_output.scheduled_new_reqs = []
-        scheduler_output.scheduled_cached_reqs = cached_requests
-
-        logger.info("Warmup 2/2: Decoding steps...")
-        for _ in range(num_decode_tokens - 1):
-            self.execute_model(scheduler_output)
+        logger.info("Warmup forward pass 2/2...")
+        self._warmup_model_forward_pass(scheduler_output, dummy_requests,
+                                        cached_requests, num_decode_tokens)
 
         warmup_end_t = time.time()
         warmup_total_t = warmup_end_t - warmup_start_t
@@ -344,6 +334,24 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
             "Warmup took %.3fs (for prompt length %d and max output tokens %d)",
             warmup_total_t, prompt_len, num_decode_tokens)
 
+    def _warmup_model_forward_pass(
+        self,
+        scheduler_output: SchedulerOutput,
+        requests: List[NewRequestData],
+        cached_requests: List[CachedRequestData],
+        num_decode_tokens,
+    ):
+        """Handle a complete forward pass"""
+        scheduler_output.scheduled_new_reqs = requests
+        scheduler_output.scheduled_cached_reqs = []
+        self.execute_model(scheduler_output)  # Prefill
+
+        # Switch to cached requests to trigger decoding steps
+        scheduler_output.scheduled_new_reqs = []
+        scheduler_output.scheduled_cached_reqs = cached_requests
+        for _ in range(num_decode_tokens - 1):
+            self.execute_model(scheduler_output)
+
     @property
     def do_metadata_broadcast(self) -> bool:
         return True
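
For context, a minimal, self-contained sketch of the warmup control flow after this refactor: each warmup pass does one prefill step on the scheduled new requests, then swaps in the cached requests and runs num_decode_tokens - 1 decode steps through execute_model. The WarmupScheduler and execute_model below are hypothetical stand-ins, not the actual vllm_spyre SchedulerOutput or worker method.

from dataclasses import dataclass, field
from typing import List


@dataclass
class WarmupScheduler:
    # Stand-in for SchedulerOutput: only the two fields the warmup toggles.
    scheduled_new_reqs: List[str] = field(default_factory=list)
    scheduled_cached_reqs: List[str] = field(default_factory=list)


def execute_model(sched: WarmupScheduler) -> None:
    # Stand-in for self.execute_model: report whether this step is a
    # prefill (new requests scheduled) or a decode (cached requests).
    step = "prefill" if sched.scheduled_new_reqs else "decode"
    n = len(sched.scheduled_new_reqs) or len(sched.scheduled_cached_reqs)
    print(f"executing {step} step for {n} request(s)")


def warmup_forward_pass(sched: WarmupScheduler, requests: List[str],
                        cached_requests: List[str],
                        num_decode_tokens: int) -> None:
    """One full forward pass: a prefill step plus num_decode_tokens - 1 decode steps."""
    sched.scheduled_new_reqs = requests
    sched.scheduled_cached_reqs = []
    execute_model(sched)  # prefill

    # Switch to cached requests to trigger the decode steps
    sched.scheduled_new_reqs = []
    sched.scheduled_cached_reqs = cached_requests
    for _ in range(num_decode_tokens - 1):
        execute_model(sched)


if __name__ == "__main__":
    sched = WarmupScheduler()
    reqs = ["warmup"]
    warmup_forward_pass(sched, reqs, reqs, num_decode_tokens=3)  # pass 1/2
    warmup_forward_pass(sched, reqs, reqs, num_decode_tokens=3)  # pass 2/2

In the real worker this pattern is wrapped in the new _warmup_model_forward_pass helper, which is why the prefill/decode bookkeeping no longer appears twice in _warmup_spyre_fixed_size.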
