
Commit 5b95b15

Store warmup shapes in config

Signed-off-by: Thomas Parnell <[email protected]>
1 parent: 9322b33

4 files changed (+5, -10 lines)
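In short, this commit moves the warmup shapes from a class attribute on SpyrePlatform (exposed through a get_warmup_shapes() classmethod) to a field that set_warmup_shapes() assigns onto the scheduler config, so every consumer reads the same config object instead of calling back into the platform. A schematic of the access change (the exact path to scheduler_config varies per call site, as the diffs below show):

# before: shapes lived on the SpyrePlatform class behind a classmethod
spyre_warmup_shapes = current_platform.get_warmup_shapes()

# after: check_and_update_config() -> set_warmup_shapes() stores the tuple
# on the scheduler config, and consumers read the attribute directly
spyre_warmup_shapes = scheduler_config.spyre_warmup_shapes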

vllm_spyre/core/scheduler.py (1 addition, 1 deletion)

@@ -669,7 +669,7 @@ def _schedule_prefills(
         seq_groups: List[ScheduledSequenceGroup] = []

         # SPYRE SPECIFIC CODE BLOCK START
-        spyre_warmup_shapes = current_platform.get_warmup_shapes()
+        spyre_warmup_shapes = self.scheduler_config.spyre_warmup_shapes
         applicable_spyre_warmup_shapes = list(spyre_warmup_shapes)
         # SPYRE SPECIFIC CODE BLOCK END

vllm_spyre/platform.py (2 additions, 7 deletions)

@@ -22,7 +22,6 @@ class SpyrePlatform(Platform):
     device_name: str = "spyre"
     device_type: str = "cpu"
     supported_quantization: list[str] = ["gptq"]
-    spyre_warmup_shapes: tuple[dict[str, int], ...]

     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:

@@ -79,7 +78,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         cls.set_warmup_shapes(scheduler_config)
         max_batch_size = 0
         max_seq_len = 0
-        for shape in cls.get_warmup_shapes():
+        for shape in scheduler_config.spyre_warmup_shapes:
             max_batch_size = max(max_batch_size, shape['batch_size'])
             max_seq_len = max(max_batch_size,
                               shape['prompt_length'] + shape['new_tokens'])

@@ -153,15 +152,11 @@ def set_warmup_shapes(cls, scheduler_config) -> None:
         logger.info("VLLM_SPYRE_WARMUP_NEW_TOKENS = %s", wup_new_tokens)
         logger.info("VLLM_SPYRE_WARMUP_BATCH_SIZES = %s", wup_batch_sizes)

-        cls.spyre_warmup_shapes = tuple(
+        scheduler_config.spyre_warmup_shapes = tuple(
             sorted([{
                 'prompt_length': pl,
                 'new_tokens': nt,
                 'batch_size': bs
             } for pl, nt, bs in zip(wup_prompt_lens, wup_new_tokens,
                                     wup_batch_sizes)],
                    key=operator.itemgetter('batch_size', 'prompt_length')))
-
-    @classmethod
-    def get_warmup_shapes(cls) -> tuple[dict[str, int], ...]:
-        return cls.spyre_warmup_shapes
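For reference, the tuple that set_warmup_shapes() now stores on scheduler_config can be reproduced standalone. A minimal sketch: the build_warmup_shapes wrapper and the example values are illustrative, but the body mirrors the construction above:

import operator

def build_warmup_shapes(wup_prompt_lens, wup_new_tokens, wup_batch_sizes):
    # One dict per warmup configuration, sorted by batch size and then
    # prompt length, matching the tuple stored on scheduler_config.
    return tuple(
        sorted([{
            'prompt_length': pl,
            'new_tokens': nt,
            'batch_size': bs
        } for pl, nt, bs in zip(wup_prompt_lens, wup_new_tokens,
                                wup_batch_sizes)],
               key=operator.itemgetter('batch_size', 'prompt_length')))

# Example (values illustrative):
# build_warmup_shapes([64, 128], [20, 20], [4, 1]) returns
# ({'prompt_length': 128, 'new_tokens': 20, 'batch_size': 1},
#  {'prompt_length': 64, 'new_tokens': 20, 'batch_size': 4})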

vllm_spyre/worker/spyre_model_runner.py (1 addition, 1 deletion)

@@ -126,7 +126,7 @@ def _prepare_prompt(
         input_token_list: List[torch.Tensor] = []

         # find warmup shape to be used for padding and batching
-        spyre_warmup_shapes = current_platform.get_warmup_shapes()
+        spyre_warmup_shapes = self.scheduler_config.spyre_warmup_shapes
         applicable_spyre_warmup_shapes = [
             shape for shape in spyre_warmup_shapes
             if len(seq_group_metadata_list) <= shape['batch_size']
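The comprehension above narrows the warmup shapes to those whose batch size can hold the incoming sequence groups. A hedged sketch of that selection as a standalone helper: only the batch-size filter appears in this hunk, so the prompt-length check and the pick-the-smallest step are assumptions about the surrounding code:

def pick_applicable_shape(warmup_shapes, num_seqs, max_prompt_len):
    # Hypothetical helper: keep shapes large enough for the batch (and,
    # assumed here, for the longest prompt), then take the first match;
    # the tuple is pre-sorted by batch_size, then prompt_length.
    applicable = [
        s for s in warmup_shapes
        if num_seqs <= s['batch_size']
        and max_prompt_len <= s['prompt_length']  # assumed extra check
    ]
    return applicable[0] if applicable else None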

vllm_spyre/worker/spyre_worker.py (1 addition, 1 deletion)

@@ -146,7 +146,7 @@ def load_model(self):
         # for all requested model warmups
         # printing env variables for debugging purposes
         load_model_start_t = time.time()
-        spyre_warmup_shapes = current_platform.get_warmup_shapes()
+        spyre_warmup_shapes = self.vllm_config.scheduler_config.spyre_warmup_shapes
         wup_prompt_lens, wup_new_tokens = zip(*[(s["prompt_length"],
                                                  s["new_tokens"])
                                                 for s in spyre_warmup_shapes])
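The zip(*...) idiom above un-pairs the shape dicts into two parallel tuples for the warmup loop; for example, with an illustrative two-shape tuple:

shapes = ({'prompt_length': 128, 'new_tokens': 20, 'batch_size': 1},
          {'prompt_length': 64, 'new_tokens': 20, 'batch_size': 4})
wup_prompt_lens, wup_new_tokens = zip(*[(s["prompt_length"], s["new_tokens"])
                                        for s in shapes])
# wup_prompt_lens == (128, 64); wup_new_tokens == (20, 20)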
