diff --git a/examples/offline_inference_spyre_cb.py b/examples/offline_inference_spyre_cb.py
index 940a18d7b..223223385 100644
--- a/examples/offline_inference_spyre_cb.py
+++ b/examples/offline_inference_spyre_cb.py
@@ -15,6 +15,8 @@
 os.environ['MASTER_ADDR'] = 'localhost'
 os.environ['MASTER_PORT'] = '12355'
 os.environ['VLLM_SPYRE_USE_CB'] = '1'
+os.environ['VLLM_SPYRE_MAX_BATCH_SIZE'] = '4'
+os.environ['VLLM_SPYRE_MAX_CONTEXT_LENGTH'] = '2048'

 # Sample prompts.
 template = (
diff --git a/vllm_spyre/envs.py b/vllm_spyre/envs.py
index 52153f99e..6eaaa90bd 100644
--- a/vllm_spyre/envs.py
+++ b/vllm_spyre/envs.py
@@ -7,6 +7,8 @@
     VLLM_SPYRE_WARMUP_NEW_TOKENS: Optional[List[int]] = None
     VLLM_SPYRE_WARMUP_BATCH_SIZES: Optional[List[int]] = None
     VLLM_SPYRE_USE_CB: bool = False
+    VLLM_SPYRE_MAX_BATCH_SIZE: int = 0
+    VLLM_SPYRE_MAX_CONTEXT_LENGTH: int = 0

 environment_variables: Dict[str, Callable[[], Any]] = {
     # Defines the prompt lengths the Spyre accelerator should be prepared
@@ -45,6 +47,14 @@
     # If set, use the V1 continuous batching implementation
     "VLLM_SPYRE_USE_CB":
     lambda: bool(int(os.getenv("VLLM_SPYRE_USE_CB", "0"))),
+
+    # Maximal supported batch size
+    "VLLM_SPYRE_MAX_BATCH_SIZE":
+    lambda: int(os.getenv("VLLM_SPYRE_MAX_BATCH_SIZE", "0")),
+
+    # Maximal supported context length
+    "VLLM_SPYRE_MAX_CONTEXT_LENGTH":
+    lambda: int(os.getenv("VLLM_SPYRE_MAX_CONTEXT_LENGTH", "0")),
 }


diff --git a/vllm_spyre/model_executor/model_loader/spyre.py b/vllm_spyre/model_executor/model_loader/spyre.py
index 61e4914d3..be0276345 100644
--- a/vllm_spyre/model_executor/model_loader/spyre.py
+++ b/vllm_spyre/model_executor/model_loader/spyre.py
@@ -15,7 +15,6 @@
 from vllm.model_executor.model_loader.weight_utils import (
     download_weights_from_hf)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.platforms import current_platform

 import vllm_spyre.envs as envs_spyre

@@ -346,15 +345,8 @@ def __init__(
             max_decode_length)

         # physical KV cache (fms wrapper/ AIU Spyre)
-        # lives in SpyreCausalLM only for convenient model access
-        warmup_shapes = current_platform.get_warmup_shapes()
-        max_batch = max(shape["batch_size"] for shape in warmup_shapes)
-        max_prompt_length = max(shape["prompt_length"]
-                                for shape in warmup_shapes)
-        max_new_tokens = max(shape["new_tokens"] for shape in warmup_shapes)
-        # Eventually max_model_len = self.config.max_position_embeddings,
-        # but saving some memory here to only allocate the max in practise
-        max_model_len = max_prompt_length + max_new_tokens
+        max_batch = envs_spyre.VLLM_SPYRE_MAX_BATCH_SIZE
+        max_model_len = envs_spyre.VLLM_SPYRE_MAX_CONTEXT_LENGTH

         if self.config.model_type == 'llama':
             num_layers = self.config.num_hidden_layers
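
A minimal sketch of how the two new settings are consumed, assuming `vllm_spyre.envs` resolves module attributes lazily through the `environment_variables` table (as vLLM's own `envs.py` does); the values below are illustrative, not defaults:

```python
import os

# Illustrative values; choose sizes that match the target Spyre configuration.
os.environ["VLLM_SPYRE_USE_CB"] = "1"
os.environ["VLLM_SPYRE_MAX_BATCH_SIZE"] = "4"
os.environ["VLLM_SPYRE_MAX_CONTEXT_LENGTH"] = "2048"

import vllm_spyre.envs as envs_spyre

# Each attribute access invokes the corresponding lambda, so the current
# process environment is read at lookup time (0 when the variable is unset).
max_batch = envs_spyre.VLLM_SPYRE_MAX_BATCH_SIZE           # -> 4
max_model_len = envs_spyre.VLLM_SPYRE_MAX_CONTEXT_LENGTH   # -> 2048
```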