vllm_spyre/platform.py (20 changes: 18 additions & 2 deletions)
@@ -77,7 +77,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # Override --max-num-seqs to the biggest warmup batch size
         # And override --max-model-len to the biggest warmup sequence
         cls._warmup_shapes = None
-        spyre_warmup_shapes = cls.get_warmup_shapes(scheduler_config)
+        max_model_len = model_config.max_model_len \
+            if model_config is not None else sys.maxsize
+        spyre_warmup_shapes = cls.get_warmup_shapes(
+            scheduler_config, max_model_len)
         max_batch_size = 0
         max_seq_len = 0
         for shape in spyre_warmup_shapes:
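
Note: when model_config is None, max_model_len falls back to sys.maxsize, which effectively disables the new length check. A minimal sketch of that fallback (illustrative, not the module's code):

    import sys

    model_config = None  # e.g. no model config is available yet
    max_model_len = (model_config.max_model_len
                     if model_config is not None else sys.maxsize)
    # sys.maxsize acts as "no limit": every warmup shape passes the check
    assert max_model_len == sys.maxsize
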
@@ -168,7 +171,10 @@ def inference_mode(cls):
         return torch.no_grad()

     @classmethod
-    def get_warmup_shapes(cls, scheduler_config) -> tuple[dict[str, int], ...]:
+    def get_warmup_shapes(
+            cls,
+            scheduler_config,
+            max_model_len: int = sys.maxsize) -> tuple[dict[str, int], ...]:
         if cls._warmup_shapes is not None:
             return cls._warmup_shapes
         # load warmup shapes and sort by "speed"
@@ -204,6 +210,16 @@ def get_warmup_shapes(cls, scheduler_config) -> tuple[dict[str, int], ...]:
             } for pl, nt, bs in zip(wup_prompt_lens, wup_new_tokens,
                                     wup_batch_sizes)],
                    key=operator.itemgetter('batch_size', 'prompt_length')))
+
+        for shape in cls._warmup_shapes:
+            max_seq_len = shape["prompt_length"] + shape["new_tokens"]
+            if max_seq_len > max_model_len:
+                raise RuntimeError(
+                    f"Warmup shape [{shape['batch_size']}, "
+                    f"{shape['prompt_length']}, {shape['new_tokens']}] "
+                    "results in a maximum sequence length of "
+                    f"{max_seq_len} which is longer than what the model "
+                    f"supports ({max_model_len})")
         return cls._warmup_shapes

     @classmethod
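
The added check can be exercised on its own. A minimal standalone sketch (check_warmup_shapes is an illustrative name, not the actual vllm_spyre API; the validation logic mirrors the diff above):

    import sys

    def check_warmup_shapes(warmup_shapes, max_model_len=sys.maxsize):
        # Each warmup shape reserves prompt_length prompt tokens plus
        # new_tokens decode steps, so their sum must fit within the
        # model's maximum sequence length.
        for shape in warmup_shapes:
            max_seq_len = shape["prompt_length"] + shape["new_tokens"]
            if max_seq_len > max_model_len:
                raise RuntimeError(
                    f"Warmup shape [{shape['batch_size']}, "
                    f"{shape['prompt_length']}, {shape['new_tokens']}] "
                    "results in a maximum sequence length of "
                    f"{max_seq_len} which is longer than what the model "
                    f"supports ({max_model_len})")

    # A 2048-token model rejects a 1792-prompt / 512-new-token shape:
    shapes = [{"batch_size": 1, "prompt_length": 1792, "new_tokens": 512}]
    check_warmup_shapes(shapes, max_model_len=2048)  # raises RuntimeError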