2 changes: 2 additions & 0 deletions examples/offline_inference_spyre_cb.py
@@ -15,6 +15,8 @@
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
os.environ['VLLM_SPYRE_USE_CB'] = '1'
os.environ['VLLM_SPYRE_MAX_BATCH_SIZE'] = '4'
os.environ['VLLM_SPYRE_MAX_CONTEXT_LENGTH'] = '2048'

# Sample prompts.
template = (
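For reference, the example now has to export both new limits together with VLLM_SPYRE_USE_CB before the engine is built. A minimal, self-contained sketch of that configuration (the model name, prompt, and sampling settings are placeholders, not taken from the example file):

```python
import os

# Continuous-batching configuration, mirroring the example script.
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
os.environ['VLLM_SPYRE_USE_CB'] = '1'
os.environ['VLLM_SPYRE_MAX_BATCH_SIZE'] = '4'         # maximal supported batch size
os.environ['VLLM_SPYRE_MAX_CONTEXT_LENGTH'] = '2048'  # prompt tokens + generated tokens

from vllm import LLM, SamplingParams  # standard vLLM entry points

llm = LLM(model="ibm-granite/granite-3.2-2b-instruct")  # placeholder model name
outputs = llm.generate(["Hello, Spyre!"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```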
10 changes: 10 additions & 0 deletions vllm_spyre/envs.py
@@ -7,6 +7,8 @@
VLLM_SPYRE_WARMUP_NEW_TOKENS: Optional[List[int]] = None
VLLM_SPYRE_WARMUP_BATCH_SIZES: Optional[List[int]] = None
VLLM_SPYRE_USE_CB: bool = False
VLLM_SPYRE_MAX_BATCH_SIZE: int = 0
VLLM_SPYRE_MAX_CONTEXT_LENGTH: int = 0

environment_variables: Dict[str, Callable[[], Any]] = {
# Defines the prompt lengths the Spyre accelerator should be prepared
@@ -45,6 +47,14 @@
# If set, use the V1 continuous batching implementation
"VLLM_SPYRE_USE_CB":
lambda: bool(int(os.getenv("VLLM_SPYRE_USE_CB", "0"))),

# Maximal supported batch size
"VLLM_SPYRE_MAX_BATCH_SIZE":
lambda: int(os.getenv("VLLM_SPYRE_MAX_BATCH_SIZE", "0")),

# Maximal supported context length
"VLLM_SPYRE_MAX_CONTEXT_LENGTH":
lambda: int(os.getenv("VLLM_SPYRE_MAX_CONTEXT_LENGTH", "0")),
}


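Both new entries follow the same lazily-evaluated convention as the existing ones: each value is a zero-argument lambda, so the environment is read when the attribute is accessed rather than when the module is imported. A minimal sketch of that access pattern, assuming the module exposes the dict through a module-level `__getattr__` as upstream vLLM's envs.py does:

```python
# envs_sketch.py -- illustrative only, not the actual vllm_spyre/envs.py
import os
from typing import Any, Callable, Dict

environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_SPYRE_MAX_BATCH_SIZE":
    lambda: int(os.getenv("VLLM_SPYRE_MAX_BATCH_SIZE", "0")),
    "VLLM_SPYRE_MAX_CONTEXT_LENGTH":
    lambda: int(os.getenv("VLLM_SPYRE_MAX_CONTEXT_LENGTH", "0")),
}


def __getattr__(name: str) -> Any:
    # Evaluate the lambda on access, so later changes to os.environ are seen.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module has no attribute {name!r}")
```

With that hook, `envs_spyre.VLLM_SPYRE_MAX_BATCH_SIZE` in spyre.py picks up whatever the example script exported, falling back to 0 when the variable is unset.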
12 changes: 2 additions & 10 deletions vllm_spyre/model_executor/model_loader/spyre.py
@@ -15,7 +15,6 @@
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform

import vllm_spyre.envs as envs_spyre

@@ -346,15 +345,8 @@ def __init__(
max_decode_length)

# physical KV cache (fms wrapper/ AIU Spyre)
# lives in SpyreCausalLM only for convenient model access
warmup_shapes = current_platform.get_warmup_shapes()
max_batch = max(shape["batch_size"] for shape in warmup_shapes)
max_prompt_length = max(shape["prompt_length"]
for shape in warmup_shapes)
max_new_tokens = max(shape["new_tokens"] for shape in warmup_shapes)
# Eventually max_model_len = self.config.max_position_embeddings,
# but saving some memory here to only allocate the max in practise
max_model_len = max_prompt_length + max_new_tokens
max_batch = envs_spyre.VLLM_SPYRE_MAX_BATCH_SIZE
max_model_len = envs_spyre.VLLM_SPYRE_MAX_CONTEXT_LENGTH

if self.config.model_type == 'llama':
num_layers = self.config.num_hidden_layers
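With this change the KV-cache limits come straight from the two environment variables rather than being derived from the warmup shapes, and both default to 0 when unset. A hypothetical guard one might place in front of the new lookup (the check and error message are illustrative additions, not part of the PR):

```python
import vllm_spyre.envs as envs_spyre

max_batch = envs_spyre.VLLM_SPYRE_MAX_BATCH_SIZE
max_model_len = envs_spyre.VLLM_SPYRE_MAX_CONTEXT_LENGTH

# Hypothetical sanity check: both variables default to 0 when unset,
# which would size the physical KV cache to nothing.
if envs_spyre.VLLM_SPYRE_USE_CB and (max_batch <= 0 or max_model_len <= 0):
    raise ValueError(
        "Continuous batching requires VLLM_SPYRE_MAX_BATCH_SIZE and "
        "VLLM_SPYRE_MAX_CONTEXT_LENGTH to be set to positive values")
```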