@@ -65,20 +65,33 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if scheduler_config.is_multi_step:
             raise NotImplementedError

-        # Near future TODO: vLLM will have an api to check whether v0 or v1 is
-        # used that isn't just checking the environment variable
+        is_decoder = model_config.task == "generate"
+        is_embedding = model_config.task == "embed"
+
+        # v0 is only supported for embedding models, and embedding models must
+        # be run on v0
+        if is_embedding and envs.VLLM_USE_V1:
+            raise ValueError("Embedding models are only supported on v0")
+        elif is_decoder and not envs.VLLM_USE_V1:
+            raise ValueError("Decoder models are only supported on v1")
+        elif not is_decoder and not is_embedding:
+            raise ValueError("Only the 'generate' and 'embed' tasks are "
+                             "supported")

         if parallel_config.worker_cls == "auto":
             parallel_config.worker_cls = (
                 f'vllm_spyre{".v1" if envs.VLLM_USE_V1 else ""}' \
                 '.worker.spyre_worker.SpyreWorker')

-        if not envs_spyre.VLLM_SPYRE_USE_CB:  # no CB
+        if envs_spyre.VLLM_SPYRE_USE_CB and is_decoder:
+            scheduler_config.scheduler_cls = "vllm_spyre.v1.core." \
+                "scheduler.ContinuousBatchingSpyreScheduler"
+        else:
+            # Static batching or embedding model.
             # Override --max-num-seqs to the biggest warmup batch size
             # And override --max-model-len to the biggest warmup sequence
             cls._warmup_shapes = None
-            max_model_len = model_config.max_model_len \
-                if model_config is not None else sys.maxsize
+            max_model_len = model_config.max_model_len
             spyre_warmup_shapes = cls.get_warmup_shapes(
                 scheduler_config, max_model_len)
             max_batch_size = 0
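
The validation block added above reduces to a small decision table: `generate` must run on v1, `embed` must run on v0, and every other task is rejected. A minimal standalone sketch of that routing, with a plain `use_v1` flag standing in for `envs.VLLM_USE_V1` so it runs outside vLLM:

```python
def validate_task(task: str, use_v1: bool) -> None:
    """Reject task/engine-version combinations the plugin can't serve."""
    is_decoder = task == "generate"
    is_embedding = task == "embed"
    if is_embedding and use_v1:
        raise ValueError("Embedding models are only supported on v0")
    elif is_decoder and not use_v1:
        raise ValueError("Decoder models are only supported on v1")
    elif not is_decoder and not is_embedding:
        raise ValueError("Only the 'generate' and 'embed' tasks are supported")

# The only combinations that pass: (generate, v1) and (embed, v0).
validate_task("generate", use_v1=True)
validate_task("embed", use_v1=False)
```
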
@@ -87,30 +100,17 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 max_batch_size = max(max_batch_size, shape["batch_size"])
                 max_seq_len = max(max_seq_len,
                                   shape["prompt_length"] + shape["new_tokens"])
-            if model_config is not None:
-                model_config.max_model_len = max_seq_len

-            if envs.VLLM_USE_V1:  # No CB with V1
+            model_config.max_model_len = max_seq_len
+            scheduler_config.max_num_seqs = max_batch_size

+            if is_decoder:
                 scheduler_config.scheduler_cls = (
                     "vllm_spyre.v1.core.scheduler." \
                     "StaticBatchingSpyreScheduler")
-                scheduler_config.max_num_seqs = max_batch_size
-            else:  # No CB with V0
+            elif is_embedding:
                 scheduler_config.scheduler_cls = (
                     "vllm_spyre.core.scheduler.SpyreScheduler")
-        else:  # CB related checks
-            if not envs.VLLM_USE_V1:  # CB with V0
-                raise NotImplementedError(
-                    "Continuous batching is only implemented for vLLM V1")
-            else:  # CB with V1
-                # As of 0.7.3 the scheduler for V1 isn't actually pluggable like
-                # this yet
-                scheduler_config.scheduler_cls = "vllm_spyre.v1.core." \
-                    "scheduler.ContinuousBatchingSpyreScheduler"
-
-        # Cache and model config aren't set in the individual worker procs
-        # These are set in the main engine process

         # To disable any paged attention ops in the base scheduler, we:
         # - Set the block size (in tokens) to the maximum sequence length
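
For reference, the warmup-shape reduction in this hunk is a plain max over the configured shapes. A self-contained sketch, assuming example shapes as dicts with the `batch_size`, `prompt_length`, and `new_tokens` keys the loop reads (real values come from `cls.get_warmup_shapes()`):

```python
# Hypothetical warmup shapes, for illustration only.
spyre_warmup_shapes = [
    {"batch_size": 4, "prompt_length": 64, "new_tokens": 20},
    {"batch_size": 8, "prompt_length": 128, "new_tokens": 20},
]

max_batch_size = 0
max_seq_len = 0
for shape in spyre_warmup_shapes:
    max_batch_size = max(max_batch_size, shape["batch_size"])
    max_seq_len = max(max_seq_len,
                      shape["prompt_length"] + shape["new_tokens"])

assert max_batch_size == 8   # overrides scheduler_config.max_num_seqs
assert max_seq_len == 148    # overrides model_config.max_model_len
```
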
@@ -137,12 +137,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             logger.info(
                 "Overriding configurations based on warmup shapes. "
                 "max_model_len=%d, max_num_seqs=%d, block_size=%d, "
-                "num_gpu_blocks_override=%d",
-                model_config.max_model_len,
-                scheduler_config.max_num_seqs,
-                cache_config.block_size,
-                cache_config.num_gpu_blocks_override,
-            )
+                "num_gpu_blocks_override=%d, max_num_batched_tokens=%d",
+                model_config.max_model_len, scheduler_config.max_num_seqs,
+                cache_config.block_size, cache_config.num_gpu_blocks_override,
+                scheduler_config.max_num_batched_tokens)

         # set env vars for torch_sendnn to consume
         os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(
@@ -227,7 +225,8 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
227225 """Returns whether the current platform can support v1 for the supplied
228226 model configuration.
229227 """
230- return True
228+ # We don't have an embedding runner for v1 yet
229+ return model_config .task != "embed"
231230
232231 @classmethod
233232 def validate_request (
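
The `supports_v1` change mirrors the earlier task checks: only embedding models are steered away from v1. A tiny illustrative harness, using `SimpleNamespace` as a stand-in for vLLM's `ModelConfig` (only the `task` attribute matters here):

```python
from types import SimpleNamespace

def supports_v1(model_config) -> bool:
    # We don't have an embedding runner for v1 yet
    return model_config.task != "embed"

assert supports_v1(SimpleNamespace(task="generate")) is True
assert supports_v1(SimpleNamespace(task="embed")) is False
```
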