Commit 972fc87

Lints, fixes prints
Signed-off-by: Rafael Vasquez <[email protected]>
1 parent aca9c1f commit 972fc87

File tree

2 files changed: +28 -20 lines changed

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 0 additions & 1 deletion
@@ -437,7 +437,6 @@ def pad_input_ids(
 
         return input_ids, position_ids, mask
 
-
     def get_kv_cache_spec(self) -> KVCacheSpec:
         """
         This method should generate the KVCache spec by parsing the kv cache

vllm_spyre/v1/worker/spyre_worker.py

Lines changed: 28 additions & 19 deletions
@@ -15,7 +15,8 @@
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
-from vllm.v1.core.scheduler import CachedRequestData, NewRequestData, SchedulerOutput
+from vllm.v1.core.scheduler import (CachedRequestData, NewRequestData,
+                                    SchedulerOutput)
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.worker_base import WorkerBase as WorkerBaseV1
@@ -81,9 +82,8 @@ def compile_or_warm_up_model(self) -> None:
             "combinations finished. Total warmup time %.3fs.",
             len(wup_new_tokens), all_warmup_total_t)
 
-
     def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
-                                special_token_ids, batch_size):
+                                 special_token_ids, batch_size):
 
         warmup_start_t = time.time()
         # NOTE(ngl): empty tensor causes spyre to hang, so using
@@ -97,7 +97,9 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
             i for i in range(1, vocab_size) if i not in set(special_token_ids)
         ]
         # Convert to tensor for sampling
-        valid_token_ids_tensor = torch.tensor(valid_token_ids, dtype=torch.long, device="cpu")
+        valid_token_ids_tensor = torch.tensor(valid_token_ids,
+                                              dtype=torch.long,
+                                              device="cpu")
 
         # Sample from the valid token ids
         warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
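To make the token-sampling pattern above easier to follow, here is a minimal standalone sketch of the same idea; vocab_size, special_token_ids, prompt_len, batch_size, and the sampled tensor shape are illustrative assumptions, not values taken from the commit.

```python
import torch

# Illustrative values; the real worker derives these from the model/config.
vocab_size = 32000
special_token_ids = [0, 1, 2]
prompt_len, batch_size = 8, 4

# Drop special tokens, then move the remaining ids into a CPU long tensor.
valid_token_ids = [
    i for i in range(1, vocab_size) if i not in set(special_token_ids)
]
valid_token_ids_tensor = torch.tensor(valid_token_ids,
                                      dtype=torch.long,
                                      device="cpu")

# Sample a (batch_size, prompt_len) grid of random but valid token ids,
# which is what the warmup prompts are built from.
warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
    0, len(valid_token_ids_tensor), (batch_size, prompt_len))]
print(warmup_tokens_tensor.shape)  # torch.Size([4, 8])
```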
@@ -106,10 +108,12 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         # Create requests to be used for prefill steps
         dummy_requests = [
             NewRequestData(
-                req_id=f"warmup",
+                req_id="warmup",
                 prompt_token_ids=warmup_tokens_tensor[i].tolist(),
                 prompt="test",
-                mm_inputs=[], mm_hashes=[], mm_positions=[],
+                mm_inputs=[],
+                mm_hashes=[],
+                mm_positions=[],
                 sampling_params=SamplingParams(max_tokens=num_decode_tokens),
                 block_ids=[0],
                 num_computed_tokens=0,
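For readers unfamiliar with these request objects, the sketch below mimics how the dummy prefill requests are assembled. DummyNewRequest is a local stand-in that mirrors only the fields visible in this diff, not the real vllm.v1.core.scheduler.NewRequestData, and sampling_params is omitted to keep the example self-contained.

```python
from dataclasses import dataclass, field

import torch


@dataclass
class DummyNewRequest:
    """Local stand-in; the real class is NewRequestData from vLLM v1."""
    req_id: str
    prompt_token_ids: list
    prompt: str
    mm_inputs: list = field(default_factory=list)
    mm_hashes: list = field(default_factory=list)
    mm_positions: list = field(default_factory=list)
    block_ids: list = field(default_factory=lambda: [0])
    num_computed_tokens: int = 0


# Fake (batch, prompt_len) grid of ids standing in for warmup_tokens_tensor.
warmup_tokens_tensor = torch.randint(5, 100, (2, 4))

dummy_requests = [
    DummyNewRequest(
        req_id="warmup",
        prompt_token_ids=warmup_tokens_tensor[i].tolist(),
        prompt="test",
    ) for i in range(warmup_tokens_tensor.shape[0])
]
print([r.prompt_token_ids for r in dummy_requests])
```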
@@ -122,8 +126,10 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
             CachedRequestData(
                 req_id=req.req_id,
                 resumed_from_preemption=False,
-                new_token_ids=[valid_token_ids_tensor[torch.randint(
-                    0, len(valid_token_ids_tensor), (1,)).item()]],  # placeholder token
+                new_token_ids=[
+                    valid_token_ids_tensor[torch.randint(
+                        0, len(valid_token_ids_tensor), (1, )).item()]
+                ],  # placeholder token
                 new_block_ids=req.block_ids,
                 num_computed_tokens=req.num_computed_tokens,
             ) for req in dummy_requests
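The placeholder token above comes from a single torch.randint draw; .item() turns the one-element result into a plain Python index. A tiny self-contained illustration (the tensor values are made up):

```python
import torch

valid_token_ids_tensor = torch.tensor([5, 6, 7, 8, 9], dtype=torch.long)

# One random index in [0, len), converted to a plain Python int by .item().
idx = torch.randint(0, len(valid_token_ids_tensor), (1, )).item()

# A single-element list: one placeholder token per simulated decode step.
new_token_ids = [int(valid_token_ids_tensor[idx])]
print(new_token_ids)
```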
@@ -134,8 +140,10 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=dummy_requests,
             scheduled_cached_reqs=[],
-            num_scheduled_tokens={i: prompt_len for i in range(batch_size)},
-            total_num_scheduled_tokens=sum(prompt_len for _ in range(batch_size)),
+            num_scheduled_tokens={i: prompt_len
+                                  for i in range(batch_size)},
+            total_num_scheduled_tokens=sum(prompt_len
+                                           for _ in range(batch_size)),
             scheduled_spec_decode_tokens={},
             scheduled_encoder_inputs={},
             num_common_prefix_blocks=0,
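As a quick sanity check on the two reflowed comprehensions above: every warmup request is scheduled for a full prompt's worth of tokens, so the total reduces to prompt_len * batch_size (the values below are made up).

```python
batch_size, prompt_len = 4, 8

num_scheduled_tokens = {i: prompt_len for i in range(batch_size)}
total_num_scheduled_tokens = sum(prompt_len for _ in range(batch_size))

# Same quantity computed three ways.
assert total_num_scheduled_tokens == prompt_len * batch_size
assert total_num_scheduled_tokens == sum(num_scheduled_tokens.values())
print(num_scheduled_tokens, total_num_scheduled_tokens)
```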
@@ -144,14 +152,14 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         )
 
         # First full forward pass
-        logger.info("[SpyreWorker] Warmup 1/2: Prefill...")
+        logger.info("Warmup 1/2: Prefill...")
         self.execute_model(scheduler_output)  # Prefill step
 
         # Switch to cached requests to trigger decoding steps
         scheduler_output.scheduled_new_reqs = []
         scheduler_output.scheduled_cached_reqs = cached_requests
 
-        logger.info("[SpyreWorker] Warmup 1/2: Decoding...")
+        logger.info("Warmup 1/2: Decoding...")
         for _ in range(num_decode_tokens - 1):
             self.execute_model(scheduler_output)
 
@@ -161,10 +169,11 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         ul_start_time = time.time()
         torch_sendnn.update_lazyhandle()
         ul_stop_time = time.time()
-        logger.info(f"update_lazyhandle() done (duration: {ul_stop_time - ul_start_time}s)")
+        logger.info("update_lazyhandle() done (duration: %.3fs)",
+                    ul_stop_time - ul_start_time)
 
         # Second full forward pass
-        logger.info("[SpyreWorker] Warmup 2/2: Prefill step...")
+        logger.info("Warmup 2/2: Prefill step...")
         scheduler_output.scheduled_new_reqs = dummy_requests
         scheduler_output.scheduled_cached_reqs = []
         self.execute_model(scheduler_output)
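The log calls are switched from f-strings to %-style arguments, which is the usual lint-driven fix: the logging framework formats the message only if the record is actually emitted, while an f-string is built eagerly on every call. A small sketch using the standard logging module (the sleep stands in for torch_sendnn.update_lazyhandle(), which is not imported here):

```python
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("spyre_worker")

ul_start_time = time.time()
time.sleep(0.01)  # stand-in for torch_sendnn.update_lazyhandle()
ul_stop_time = time.time()

# Eager: the f-string is evaluated even when INFO is disabled.
logger.info(
    f"update_lazyhandle() done (duration: {ul_stop_time - ul_start_time}s)")

# Lazy: formatting is deferred to the logging framework.
logger.info("update_lazyhandle() done (duration: %.3fs)",
            ul_stop_time - ul_start_time)
```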
@@ -173,16 +182,16 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         scheduler_output.scheduled_new_reqs = []
         scheduler_output.scheduled_cached_reqs = cached_requests
 
-        logger.info("[SpyreWorker] Warmup 2/2: Decoding steps...")
+        logger.info("Warmup 2/2: Decoding steps...")
         for _ in range(num_decode_tokens - 1):
             self.execute_model(scheduler_output)
 
         warmup_end_t = time.time()
         warmup_total_t = warmup_end_t - warmup_start_t
-        logger.info("[SpyreWorker] ... warmup finished.")
-        logger.info(f"\twarmup took {warmup_total_t}s (for prompt length"
-                    f"{prompt_len} and max output tokens {num_decode_tokens})")
-
+        logger.info("Warmup finished.")
+        logger.info(
+            "Warmup took %.3fs (for prompt length %d and max output tokens %d)",
+            warmup_total_t, prompt_len, num_decode_tokens)
 
     def check_health(self) -> None:
         """Basic health check (override for device-specific checks)."""

0 commit comments
