diff --git a/examples/offline_inference/long_context.py b/examples/offline_inference/long_context.py
index 019d9a200..e52c56221 100644
--- a/examples/offline_inference/long_context.py
+++ b/examples/offline_inference/long_context.py
@@ -56,7 +56,7 @@ trunc = args.trunc_print_len
 
 max_num_seqs = args.max_num_seqs  # defines the max batch size
 
-assert args.max_prompt_len < args.max_model_len
+assert args.max_prompt_len <= args.max_model_len
 
 if platform.machine() == "arm64":
     print("Detected arm64 running environment. "
@@ -122,7 +122,7 @@ def round_up(t):
 
 
 tokens_to_generate = [
-    args.max_model_len - round_up(plen) for plen in prompt_lens
+    args.max_model_len + 1 - round_up(prompt_len) for prompt_len in prompt_lens
 ]
 
 sampling_params = [
diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index 20007d501..f8e22d1fa 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -284,7 +284,8 @@ def validate_request(
             # ceil division to pad to next block boundary
             prompt_padding_len = math.ceil(
                 prompt_len / cls._block_size) * cls._block_size
-            if (prompt_padding_len + max_tokens
+            # we have to account for the token generated during prefill (-1)
+            if (prompt_padding_len + max_tokens - 1
                     > cls._config.scheduler_config.max_model_len):
                 raise ValueError(
                     "Could not add request: prompt length is "
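
Below is a minimal, hypothetical sketch (not part of the patch) of the arithmetic the vllm_spyre/platform.py hunk validates: the prompt is padded up to the next block boundary, and because the first output token is produced during prefill, only max_tokens - 1 additional positions are needed beyond the padded prompt. The names BLOCK_SIZE, MAX_MODEL_LEN, and fits_in_context are illustrative assumptions, not identifiers from the repository.

import math

BLOCK_SIZE = 64        # assumed block size, for illustration only
MAX_MODEL_LEN = 2048   # assumed model context length, for illustration only


def fits_in_context(prompt_len: int, max_tokens: int) -> bool:
    # ceil division pads the prompt to the next block boundary
    prompt_padding_len = math.ceil(prompt_len / BLOCK_SIZE) * BLOCK_SIZE
    # the token generated during prefill already fills one of the
    # max_tokens slots, hence the -1 in the capacity check
    return prompt_padding_len + max_tokens - 1 <= MAX_MODEL_LEN


# A request whose only output token comes from prefill fits exactly at the limit...
assert fits_in_context(prompt_len=2048, max_tokens=1)
# ...while one more decode step would need a position past max_model_len.
assert not fits_in_context(prompt_len=2048, max_tokens=2)

The same off-by-one reasoning appears consistent with the + 1 added in the long_context.py hunk, where tokens_to_generate is sized so that the padded prompt plus the requested tokens, minus the token produced during prefill, lands exactly on max_model_len.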