diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index c22ed107b..61c133c4b 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -77,7 +77,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # Override --max-num-seqs to the biggest warmup batch size
         # And override --max-model-len to the biggest warmup sequence
         cls._warmup_shapes = None
-        spyre_warmup_shapes = cls.get_warmup_shapes(scheduler_config)
+        max_model_len = model_config.max_model_len \
+            if model_config is not None else sys.maxsize
+        spyre_warmup_shapes = cls.get_warmup_shapes(
+            scheduler_config, max_model_len)
         max_batch_size = 0
         max_seq_len = 0
         for shape in spyre_warmup_shapes:
@@ -168,7 +171,10 @@ def inference_mode(cls):
         return torch.no_grad()
 
     @classmethod
-    def get_warmup_shapes(cls, scheduler_config) -> tuple[dict[str, int], ...]:
+    def get_warmup_shapes(
+            cls,
+            scheduler_config,
+            max_model_len: int = sys.maxsize) -> tuple[dict[str, int], ...]:
         if cls._warmup_shapes is not None:
             return cls._warmup_shapes
         # load warmup shapes and sort by "speed"
@@ -204,6 +210,16 @@ def get_warmup_shapes(cls, scheduler_config) -> tuple[dict[str, int], ...]:
             } for pl, nt, bs in zip(wup_prompt_lens, wup_new_tokens,
                                     wup_batch_sizes)],
                   key=operator.itemgetter('batch_size', 'prompt_length')))
+
+        for shape in cls._warmup_shapes:
+            max_seq_len = shape["prompt_length"] + shape["new_tokens"]
+            if max_seq_len > max_model_len:
+                raise RuntimeError(
+                    f"Warmup shape [{shape['batch_size']},"
+                    f" {shape['prompt_length']}, {shape['new_tokens']}]"
+                    " results in a maximum sequence length of "
+                    f"{max_seq_len} which is longer than what the model "
+                    f"supports ({max_model_len})")
         return cls._warmup_shapes
 
     @classmethod
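
As a quick sanity check of the new validation, here is a minimal standalone sketch of the behaviour the last hunk adds; check_warmup_shapes and the example shape values are illustrative stand-ins, not part of the patch itself:

    # Illustrative sketch only: mirrors the validation added above,
    # outside of the SpyrePlatform class.
    def check_warmup_shapes(warmup_shapes, max_model_len):
        for shape in warmup_shapes:
            max_seq_len = shape["prompt_length"] + shape["new_tokens"]
            if max_seq_len > max_model_len:
                raise RuntimeError(
                    f"Warmup shape [{shape['batch_size']},"
                    f" {shape['prompt_length']}, {shape['new_tokens']}]"
                    " results in a maximum sequence length of "
                    f"{max_seq_len} which is longer than what the model "
                    f"supports ({max_model_len})")

    # A 2048-token model combined with a warmup shape of 1024 prompt tokens
    # and 2048 new tokens raises, since 1024 + 2048 = 3072 > 2048.
    check_warmup_shapes(
        [{"batch_size": 1, "prompt_length": 1024, "new_tokens": 2048}],
        max_model_len=2048)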