examples/offline_inference/long_context.py (4 changes: 2 additions & 2 deletions)
@@ -56,7 +56,7 @@
trunc = args.trunc_print_len

max_num_seqs = args.max_num_seqs # defines the max batch size
-assert args.max_prompt_len < args.max_model_len
+assert args.max_prompt_len <= args.max_model_len

if platform.machine() == "arm64":
print("Detected arm64 running environment. "
@@ -122,7 +122,7 @@ def round_up(t):


tokens_to_generate = [
-    args.max_model_len - round_up(plen) for plen in prompt_lens
+    args.max_model_len + 1 - round_up(plen) for plen in prompt_lens
Collaborator:
I cannot help but read plen as if it was one word, can we rename that?

Suggested change:
-    args.max_model_len + 1 - round_up(plen) for plen in prompt_lens
+    args.max_model_len + 1 - round_up(p_len) for p_len in prompt_lens

or

Suggested change:
-    args.max_model_len + 1 - round_up(plen) for plen in prompt_lens
+    args.max_model_len + 1 - round_up(prompt_len) for prompt_lenin prompt_lens

Collaborator:
+1 for prompt_len, we don't need to save the bytes

Collaborator:
prompt_lenin

]

sampling_params = [
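For reference, a minimal sketch of the token budget this change implements: the first output token is produced during prefill, hence the "+ 1". The block size of 64 and the sample lengths below are illustrative assumptions, and round_up is assumed to pad the prompt length to the next block boundary, mirroring the check in vllm_spyre/platform.py further down.

```python
import math

# Assumed block size, for illustration only; the real value comes from the
# Spyre configuration.
BLOCK_SIZE = 64


def round_up(prompt_len: int) -> int:
    # Pad the prompt length up to the next block boundary (ceil division).
    return math.ceil(prompt_len / BLOCK_SIZE) * BLOCK_SIZE


def tokens_to_generate(max_model_len: int, prompt_len: int) -> int:
    # The "+ 1" accounts for the token generated during prefill: a prompt
    # padded to exactly max_model_len can still produce one output token.
    return max_model_len + 1 - round_up(prompt_len)


if __name__ == "__main__":
    max_model_len = 2048
    for prompt_len in (1, 64, 2048):
        print(prompt_len, round_up(prompt_len),
              tokens_to_generate(max_model_len, prompt_len))
```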
vllm_spyre/platform.py (3 changes: 2 additions & 1 deletion)
@@ -284,7 +284,8 @@ def validate_request(
# ceil division to pad to next block boundary
prompt_padding_len = math.ceil(
prompt_len / cls._block_size) * cls._block_size
-if (prompt_padding_len + max_tokens
+# we have to account for the token generated during prefill (-1)
+if (prompt_padding_len + max_tokens - 1
> cls._config.scheduler_config.max_model_len):
raise ValueError(
"Could not add request: prompt length is "
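And a minimal sketch of the updated validation check, assuming an illustrative block size of 64 and max_model_len of 2048 (in the source these come from cls._block_size and the scheduler config); the "- 1" mirrors the new comment about the token generated during prefill.

```python
import math

# Illustrative values only; in platform.py these come from cls._block_size
# and cls._config.scheduler_config.max_model_len.
BLOCK_SIZE = 64
MAX_MODEL_LEN = 2048


def validate_request_len(prompt_len: int, max_tokens: int) -> None:
    # ceil division to pad to next block boundary
    prompt_padding_len = math.ceil(prompt_len / BLOCK_SIZE) * BLOCK_SIZE
    # account for the token generated during prefill (-1)
    if prompt_padding_len + max_tokens - 1 > MAX_MODEL_LEN:
        raise ValueError(
            f"Could not add request: prompt length {prompt_len} (padded to "
            f"{prompt_padding_len}) plus {max_tokens} output tokens exceeds "
            f"max_model_len {MAX_MODEL_LEN}")


# A prompt that pads out to exactly MAX_MODEL_LEN may still request one
# token, thanks to the "- 1"; this call does not raise.
validate_request_len(prompt_len=2040, max_tokens=1)
```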