Commit b9c03b6

Move function back
Signed-off-by: Rafael Vasquez <[email protected]>
1 parent 972fc87 commit b9c03b6

File tree

1 file changed (+111 -111 lines)

vllm_spyre/v1/worker/spyre_worker.py

Lines changed: 111 additions & 111 deletions
@@ -82,117 +82,6 @@ def compile_or_warm_up_model(self) -> None:
             "combinations finished. Total warmup time %.3fs.",
             len(wup_new_tokens), all_warmup_total_t)
 
-    def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
-                                 special_token_ids, batch_size):
-
-        warmup_start_t = time.time()
-        # NOTE(ngl): empty tensor causes spyre to hang, so using
-        # randint without 0 and the eos and bos token
-
-        # Create a list of valid values between 1 (inclusive) and vocab
-        # size (exclusive) by excluding the eos and bos token ids
-        # (in special_token_ids)
-        vocab_size = self.model_runner.vocab_size
-        valid_token_ids = [
-            i for i in range(1, vocab_size) if i not in set(special_token_ids)
-        ]
-        # Convert to tensor for sampling
-        valid_token_ids_tensor = torch.tensor(valid_token_ids,
-                                              dtype=torch.long,
-                                              device="cpu")
-
-        # Sample from the valid token ids
-        warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
-            0, len(valid_token_ids_tensor), (batch_size, prompt_len))]
-
-        # Create requests to be used for prefill steps
-        dummy_requests = [
-            NewRequestData(
-                req_id="warmup",
-                prompt_token_ids=warmup_tokens_tensor[i].tolist(),
-                prompt="test",
-                mm_inputs=[],
-                mm_hashes=[],
-                mm_positions=[],
-                sampling_params=SamplingParams(max_tokens=num_decode_tokens),
-                block_ids=[0],
-                num_computed_tokens=0,
-                lora_request=None,
-            ) for i in range(batch_size)
-        ]
-
-        # Set up dummy cached_requests to be used for decode steps
-        cached_requests = [
-            CachedRequestData(
-                req_id=req.req_id,
-                resumed_from_preemption=False,
-                new_token_ids=[
-                    valid_token_ids_tensor[torch.randint(
-                        0, len(valid_token_ids_tensor), (1, )).item()]
-                ],  # placeholder token
-                new_block_ids=req.block_ids,
-                num_computed_tokens=req.num_computed_tokens,
-            ) for req in dummy_requests
-        ]
-
-        # To be used for execute_model, start with scheduled_new_reqs
-        # for prefill
-        scheduler_output = SchedulerOutput(
-            scheduled_new_reqs=dummy_requests,
-            scheduled_cached_reqs=[],
-            num_scheduled_tokens={i: prompt_len
-                                  for i in range(batch_size)},
-            total_num_scheduled_tokens=sum(prompt_len
-                                           for _ in range(batch_size)),
-            scheduled_spec_decode_tokens={},
-            scheduled_encoder_inputs={},
-            num_common_prefix_blocks=0,
-            finished_req_ids=set(),
-            free_encoder_input_ids=[],
-        )
-
-        # First full forward pass
-        logger.info("Warmup 1/2: Prefill...")
-        self.execute_model(scheduler_output)  # Prefill step
-
-        # Switch to cached requests to trigger decoding steps
-        scheduler_output.scheduled_new_reqs = []
-        scheduler_output.scheduled_cached_reqs = cached_requests
-
-        logger.info("Warmup 1/2: Decoding...")
-        for _ in range(num_decode_tokens - 1):
-            self.execute_model(scheduler_output)
-
-        # update_lazyhandle
-        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
-            from torch_sendnn import torch_sendnn
-            ul_start_time = time.time()
-            torch_sendnn.update_lazyhandle()
-            ul_stop_time = time.time()
-            logger.info("update_lazyhandle() done (duration: %.3fs",
-                        ul_stop_time - ul_start_time)
-
-        # Second full forward pass
-        logger.info("Warmup 2/2: Prefill step...")
-        scheduler_output.scheduled_new_reqs = dummy_requests
-        scheduler_output.scheduled_cached_reqs = []
-        self.execute_model(scheduler_output)
-
-        # Switch to cached requests to trigger decoding steps
-        scheduler_output.scheduled_new_reqs = []
-        scheduler_output.scheduled_cached_reqs = cached_requests
-
-        logger.info("[Warmup 2/2: Decoding steps...")
-        for _ in range(num_decode_tokens - 1):
-            self.execute_model(scheduler_output)
-
-        warmup_end_t = time.time()
-        warmup_total_t = warmup_end_t - warmup_start_t
-        logger.info("Warmup finished.")
-        logger.info(
-            "Warmup took %.3fs (for prompt length %d and max output tokens %d)",
-            warmup_total_t, prompt_len, num_decode_tokens)
-
     def check_health(self) -> None:
         """Basic health check (override for device-specific checks)."""
         # TODO: Implement something!
@@ -344,6 +233,117 @@ def load_model(self):
         load_model_total_t = load_model_end_t - load_model_start_t
         logger.info("load model took %.3fs", load_model_total_t)
 
+    def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
+                                 special_token_ids, batch_size):
+
+        warmup_start_t = time.time()
+        # NOTE(ngl): empty tensor causes spyre to hang, so using
+        # randint without 0 and the eos and bos token
+
+        # Create a list of valid values between 1 (inclusive) and vocab
+        # size (exclusive) by excluding the eos and bos token ids
+        # (in special_token_ids)
+        vocab_size = self.model_runner.vocab_size
+        valid_token_ids = [
+            i for i in range(1, vocab_size) if i not in set(special_token_ids)
+        ]
+        # Convert to tensor for sampling
+        valid_token_ids_tensor = torch.tensor(valid_token_ids,
+                                              dtype=torch.long,
+                                              device="cpu")
+
+        # Sample from the valid token ids
+        warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
+            0, len(valid_token_ids_tensor), (batch_size, prompt_len))]
+
+        # Create requests to be used for prefill steps
+        dummy_requests = [
+            NewRequestData(
+                req_id="warmup",
+                prompt_token_ids=warmup_tokens_tensor[i].tolist(),
+                prompt="test",
+                mm_inputs=[],
+                mm_hashes=[],
+                mm_positions=[],
+                sampling_params=SamplingParams(max_tokens=num_decode_tokens),
+                block_ids=[0],
+                num_computed_tokens=0,
+                lora_request=None,
+            ) for i in range(batch_size)
+        ]
+
+        # Set up dummy cached_requests to be used for decode steps
+        cached_requests = [
+            CachedRequestData(
+                req_id=req.req_id,
+                resumed_from_preemption=False,
+                new_token_ids=[
+                    valid_token_ids_tensor[torch.randint(
+                        0, len(valid_token_ids_tensor), (1, )).item()]
+                ],  # placeholder token
+                new_block_ids=req.block_ids,
+                num_computed_tokens=req.num_computed_tokens,
+            ) for req in dummy_requests
+        ]
+
+        # To be used for execute_model, start with scheduled_new_reqs
+        # for prefill
+        scheduler_output = SchedulerOutput(
+            scheduled_new_reqs=dummy_requests,
+            scheduled_cached_reqs=[],
+            num_scheduled_tokens={i: prompt_len
+                                  for i in range(batch_size)},
+            total_num_scheduled_tokens=sum(prompt_len
+                                           for _ in range(batch_size)),
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=0,
+            finished_req_ids=set(),
+            free_encoder_input_ids=[],
+        )
+
+        # First full forward pass
+        logger.info("Warmup 1/2: Prefill...")
+        self.execute_model(scheduler_output)  # Prefill step
+
+        # Switch to cached requests to trigger decoding steps
+        scheduler_output.scheduled_new_reqs = []
+        scheduler_output.scheduled_cached_reqs = cached_requests
+
+        logger.info("Warmup 1/2: Decoding...")
+        for _ in range(num_decode_tokens - 1):
+            self.execute_model(scheduler_output)
+
+        # update_lazyhandle
+        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
+            from torch_sendnn import torch_sendnn
+            ul_start_time = time.time()
+            torch_sendnn.update_lazyhandle()
+            ul_stop_time = time.time()
+            logger.info("update_lazyhandle() done (duration: %.3fs",
+                        ul_stop_time - ul_start_time)
+
+        # Second full forward pass
+        logger.info("Warmup 2/2: Prefill step...")
+        scheduler_output.scheduled_new_reqs = dummy_requests
+        scheduler_output.scheduled_cached_reqs = []
+        self.execute_model(scheduler_output)
+
+        # Switch to cached requests to trigger decoding steps
+        scheduler_output.scheduled_new_reqs = []
+        scheduler_output.scheduled_cached_reqs = cached_requests
+
+        logger.info("[Warmup 2/2: Decoding steps...")
+        for _ in range(num_decode_tokens - 1):
+            self.execute_model(scheduler_output)
+
+        warmup_end_t = time.time()
+        warmup_total_t = warmup_end_t - warmup_start_t
+        logger.info("Warmup finished.")
+        logger.info(
+            "Warmup took %.3fs (for prompt length %d and max output tokens %d)",
+            warmup_total_t, prompt_len, num_decode_tokens)
+
     @property
     def do_metadata_broadcast(self) -> bool:
         return True
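
For reference, the token-sampling trick at the top of the moved _warmup_spyre_fixed_size method can be exercised on its own. The snippet below is a minimal standalone sketch; the vocab_size, special_token_ids, batch_size, and prompt_len values are illustrative stand-ins for what the worker reads from its model runner, not values from this commit.

import torch

# Hypothetical stand-ins for self.model_runner.vocab_size and the model's
# bos/eos ids; the real method pulls these from the model runner.
vocab_size = 32
special_token_ids = [0, 2]  # illustrative bos/eos ids
batch_size = 4
prompt_len = 8

# Valid ids run from 1 (inclusive) to vocab_size (exclusive), minus the
# special tokens, matching the NOTE(ngl) comment about avoiding 0/bos/eos.
valid_token_ids = [
    i for i in range(1, vocab_size) if i not in set(special_token_ids)
]
valid_token_ids_tensor = torch.tensor(valid_token_ids, dtype=torch.long)

# Sample a (batch_size, prompt_len) block of warmup prompt tokens.
warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
    0, len(valid_token_ids_tensor), (batch_size, prompt_len))]
print(warmup_tokens_tensor.shape)  # torch.Size([4, 8])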
