Merged
Changes from 11 commits
83 changes: 83 additions & 0 deletions examples/offline_inference_spyre_cb_test.py
@@ -0,0 +1,83 @@
import os
import time

from vllm import LLM, SamplingParams

max_tokens1 = 10
max_tokens2 = 5
max_tokens3 = 7
max_tokens = max([max_tokens1, max_tokens2, max_tokens3])
max_num_seqs = 2 # defines max batch size

os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = '64'
os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max_tokens)

# defined here to be able to run/debug directly from VS Code (not via terminal)
os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = 'eager'
os.environ['VLLM_SPYRE_USE_CB'] = '1'
os.environ['VLLM_USE_V1'] = '1'

os.environ['VLLM_SPYRE_MAX_CONTEXT_LENGTH'] = '2048'
os.environ['VLLM_SPYRE_MAX_BATCH_SIZE'] = str(max_num_seqs)

# Sample prompts.
template = (
"Below is an instruction that describes a task. Write a response that "
"appropriately completes the request. Be polite in your response to the "
"user.\n\n### Instruction:\n{}\n\n### Response:")

prompt1 = template.format(
"Provide a list of instructions for preparing chicken soup for a family "
"of four.")

prompt2 = template.format("Provide instructions for preparing chicken soup.")

prompt3 = template.format(
"Provide a list of instructions for preparing chicken soup for a family.")

prompts = [
prompt1,
prompt2,
prompt3,
]

# Create a sampling params object.
sampling_params1 = SamplingParams(max_tokens=max_tokens1,
temperature=0.0,
ignore_eos=True)

sampling_params2 = SamplingParams(max_tokens=max_tokens2,
temperature=0.0,
ignore_eos=True)

sampling_params3 = SamplingParams(max_tokens=max_tokens3,
temperature=0.0,
ignore_eos=True)

sampling_params = [
sampling_params1,
sampling_params2,
sampling_params3,
]

# Create an LLM.
llm = LLM(model="/models/llama-194m",
tokenizer="/models/llama-194m",
max_model_len=2048,
block_size=2048)

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
print("=============== GENERATE")
t0 = time.time()
outputs = llm.generate(prompts, sampling_params)
print("Time elaspsed for %d tokens is %.2f sec" %
(len(outputs[0].outputs[0].token_ids), time.time() - t0))
print("===============")
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print("===============")
for output in outputs:
print(output.outputs[0])
1 change: 0 additions & 1 deletion tests/test_spyre_tensor_parallel.py
@@ -46,7 +46,6 @@ def test_output(
test using 'pytest --capture=no tests/spyre/test_spyre_tensore_parallel.py'
After debugging, DISABLE_ASSERTS should be reset to 'False'.
'''

max_new_tokens = max([t[1] for t in warmup_shapes])

vllm_sampling_params = SamplingParams(
15 changes: 15 additions & 0 deletions vllm_spyre/envs.py
@@ -6,6 +6,9 @@
VLLM_SPYRE_WARMUP_PROMPT_LENS: Optional[List[int]] = None
VLLM_SPYRE_WARMUP_NEW_TOKENS: Optional[List[int]] = None
VLLM_SPYRE_WARMUP_BATCH_SIZES: Optional[List[int]] = None
VLLM_SPYRE_USE_CB: bool = False
VLLM_SPYRE_MAX_BATCH_SIZE: int = 0
VLLM_SPYRE_MAX_CONTEXT_LENGTH: int = 0
Comment on lines +10 to +11
Member:
Why do we need these environment variables? Can't we use max-num-seqs and max-model-len directly?

Collaborator Author:
I believe it was agreed on in the meeting with the compiler team that these should be env variables. They will be used on their end too...

Collaborator:
🤔 🤔 🤔
But the compiler shouldn't be looking at vLLM-specific environment variables, right? That seems like coupling in the wrong way since vllm is a consumer of the compiler, not the other way around. What I would naively expect is that if the compiler requires some env vars to be set, then we would take care of setting them in the plugin code based on vLLM's configuration.

Also, IIUC these values are all currently derivable from the provided warmup shapes, right? So requiring users to configure them here is confusing, and can lead to broken configurations like

VLLM_SPYRE_WARMUP_BATCH_SIZES=1,2,3
VLLM_SPYRE_MAX_BATCH_SIZE=2

Collaborator:
Ah, after looking at the scheduler I see that it looks like we're no longer using the static warmup shapes for scheduling with continuous batching. Are those now going to be a relic of the past?

That would be super nice, though I would still say we should be using vllm's existing --max-model-len and --max-num-seqs to keep a single source of configuration for these values
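
For illustration, a minimal sketch of that idea (the helper name is hypothetical; it assumes the plugin has access to a VllmConfig exposing scheduler_config.max_num_seqs and model_config.max_model_len, and it reuses the env var names added in this PR rather than any confirmed compiler interface):

```python
import os

from vllm.config import VllmConfig  # assumed import path


def propagate_spyre_env_vars(vllm_config: VllmConfig) -> None:
    """Hypothetical helper: derive the compiler-facing env vars from
    vLLM's own --max-num-seqs / --max-model-len instead of requiring
    users to set them separately."""
    os.environ["VLLM_SPYRE_MAX_BATCH_SIZE"] = str(
        vllm_config.scheduler_config.max_num_seqs)
    os.environ["VLLM_SPYRE_MAX_CONTEXT_LENGTH"] = str(
        vllm_config.model_config.max_model_len)
```

Something like this could be called from check_and_update_config, so --max-num-seqs and --max-model-len stay the single source of configuration.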

Collaborator Author:
Yes, warmup shapes will be a relic of the past as we move towards supporting dynamic dimensions. AFAIK, the way of communication between the compiler and vllm is not yet fully determined, and it was decided in one of the meetings with the compiler team that (for the time being) there will be two env variables used for sharing information between vllm and the compiler. I do agree that they eventually will be set by the compiler, but as we emulate on CPU here (hence no AIU Spyre compiler involved), we simply set them ourselves.

Collaborator Author:
Would it be okay to address the proper args handling in another PR? To me it is not straightforward to see why we have/need two calls to check_and_update_config in platform.py and why scheduler_config.max_num_seqs varies between the two. Also, this is not specific to this branch (it happens on main too). Of course, if anyone has an immediate solution, I am happy to include it here :)

Member (@tdoublep, Apr 9, 2025):
Yes, we can address it as a follow-up, fine with me.

> Are those now going to be a relic of the past?

And to address @joerunde's question here: yes, the warmup shapes will be a relic of the past. Things start to become much more similar to how it works on GPU.

Collaborator:
> yes, the warmup shapes will be a relic of the past

nice!

Collaborator:
I have found the issue as to why there are two calls to check_and_update_config in platform.py - will update shortly!

Collaborator:
Check out #114


environment_variables: Dict[str, Callable[[], Any]] = {
# Defines the prompt lengths the Spyre accelerator should be prepared
@@ -40,6 +43,18 @@
# - "eager": Skip compile entirely (for debug and testing
"VLLM_SPYRE_DYNAMO_BACKEND":
lambda: os.getenv("VLLM_SPYRE_DYNAMO_BACKEND", "sendnn_decoder"),

# If set, use the V1 continuous batching implementation
"VLLM_SPYRE_USE_CB":
lambda: bool(int(os.getenv("VLLM_SPYRE_USE_CB", "0"))),

# Maximal supported batch size
"VLLM_SPYRE_MAX_BATCH_SIZE":
lambda: int(os.getenv("VLLM_SPYRE_MAX_BATCH_SIZE", "0")),

# Maximal supported context length
"VLLM_SPYRE_MAX_CONTEXT_LENGTH":
lambda: int(os.getenv("VLLM_SPYRE_MAX_CONTEXT_LENGTH", "0")),
}
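
For context, a hedged sketch of how plugin code might read these values (assuming this module follows vLLM's usual envs pattern, where module attributes are resolved through the environment_variables dict via a __getattr__ hook):

```python
import vllm_spyre.envs as envs

# Hypothetical consumer: pick the execution path based on the new flags.
if envs.VLLM_SPYRE_USE_CB:
    # 0 is the documented default for both values, i.e. "not set".
    max_batch = envs.VLLM_SPYRE_MAX_BATCH_SIZE
    max_context = envs.VLLM_SPYRE_MAX_CONTEXT_LENGTH
    print(f"continuous batching: batch={max_batch}, context={max_context}")
else:
    print("static warmup-shape batching")
```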

