
Commit c00da14

🥅 disable v0 decoders (#242)
Description: This updates our platform.py to disable v0 for generation, since v0 will no longer be supported in fms 1.1.0. Related to #241.

Unfortunately, the entire v0 stack needs to remain in place for embedding models, which require the v0 scheduler, the v0 worker, and the v0 embeddings model runner (which is a child of the v0 decoder model runner).

Signed-off-by: Joe Runde <[email protected]>
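For quick reference, the task/engine check this commit adds to check_and_update_config can be restated as the standalone sketch below. The function name check_task_engine_combo and the plain task/use_v1 parameters are illustrative stand-ins for model_config.task and envs.VLLM_USE_V1, not part of the real code.

# Illustrative sketch only: restates the validation added to platform.py,
# with plain arguments standing in for vLLM's config objects.
def check_task_engine_combo(task: str, use_v1: bool) -> None:
    is_decoder = task == "generate"
    is_embedding = task == "embed"

    # v0 is only supported for embedding models, and embedding models
    # must be run on v0.
    if is_embedding and use_v1:
        raise ValueError("Embedding models are only supported on v0")
    elif is_decoder and not use_v1:
        raise ValueError("Decoder models are only supported on v1")
    elif not is_decoder and not is_embedding:
        raise ValueError("Only the 'generate' and 'embed' tasks are supported")

check_task_engine_combo("generate", use_v1=True)  # decoders stay on v1
check_task_engine_combo("embed", use_v1=False)    # embedding models stay on v0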
1 parent 14dd264 commit c00da14

File tree

1 file changed: +28 −29 lines changed


vllm_spyre/platform.py

Lines changed: 28 additions & 29 deletions
@@ -65,20 +65,33 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if scheduler_config.is_multi_step:
             raise NotImplementedError
 
-        # Near future TODO: vLLM will have an api to check whether v0 or v1 is
-        # used that isn't just checking the environment variable
+        is_decoder = model_config.task == "generate"
+        is_embedding = model_config.task == "embed"
+
+        # v0 is only supported for embedding models, and embedding models must
+        # be run on v0
+        if is_embedding and envs.VLLM_USE_V1:
+            raise ValueError("Embedding models are only supported on v0")
+        elif is_decoder and not envs.VLLM_USE_V1:
+            raise ValueError("Decoder models are only supported on v1")
+        elif not is_decoder and not is_embedding:
+            raise ValueError("Only the 'generate' and 'embed' tasks are "
+                             "supported")
 
         if parallel_config.worker_cls == "auto":
             parallel_config.worker_cls = (
                 f'vllm_spyre{".v1" if envs.VLLM_USE_V1 else ""}'\
                 '.worker.spyre_worker.SpyreWorker')
 
-        if not envs_spyre.VLLM_SPYRE_USE_CB:  # no CB
+        if envs_spyre.VLLM_SPYRE_USE_CB and is_decoder:
+            scheduler_config.scheduler_cls = "vllm_spyre.v1.core."\
+                "scheduler.ContinuousBatchingSpyreScheduler"
+        else:
+            # Static batching or embedding model.
             # Override --max-num-seqs to the biggest warmup batch size
             # And override --max-model-len to the biggest warmup sequence
             cls._warmup_shapes = None
-            max_model_len = model_config.max_model_len \
-                if model_config is not None else sys.maxsize
+            max_model_len = model_config.max_model_len
             spyre_warmup_shapes = cls.get_warmup_shapes(
                 scheduler_config, max_model_len)
             max_batch_size = 0
@@ -87,30 +100,17 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 max_batch_size = max(max_batch_size, shape["batch_size"])
                 max_seq_len = max(max_seq_len,
                                   shape["prompt_length"] + shape["new_tokens"])
-            if model_config is not None:
-                model_config.max_model_len = max_seq_len
 
-            if envs.VLLM_USE_V1:  # No CB with V1
+            model_config.max_model_len = max_seq_len
+            scheduler_config.max_num_seqs = max_batch_size
 
+            if is_decoder:
                 scheduler_config.scheduler_cls = (
                     "vllm_spyre.v1.core.scheduler."\
                     "StaticBatchingSpyreScheduler")
-                scheduler_config.max_num_seqs = max_batch_size
-            else:  # No CB with V0
+            elif is_embedding:
                 scheduler_config.scheduler_cls = (
                     "vllm_spyre.core.scheduler.SpyreScheduler")
-        else:  # CB related checks
-            if not envs.VLLM_USE_V1:  # CB with V0
-                raise NotImplementedError(
-                    "Continuous batching is only implemented for vLLM V1")
-            else:  # CB with V1
-                # As of 0.7.3 the scheduler for V1 isn't actually pluggable like
-                # this yet
-                scheduler_config.scheduler_cls = "vllm_spyre.v1.core."\
-                    "scheduler.ContinuousBatchingSpyreScheduler"
-
-        # Cache and model config aren't set in the individual worker procs
-        # These are set in the main engine process
 
         # To disable any paged attention ops in the base scheduler, we:
         # - Set the block size (in tokens) to the maximum sequence length
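Read together, the two hunks above collapse the scheduler selection into one decision over the model task and the continuous-batching flag. A hedged restatement, assuming plain string/bool inputs in place of model_config.task and envs_spyre.VLLM_SPYRE_USE_CB:

# Sketch of the scheduler choice after this change; class paths are taken from
# the diff, but the real code also rewrites max_model_len and max_num_seqs from
# the warmup shapes in the static-batching/embedding path.
def pick_scheduler_cls(task: str, use_cb: bool) -> str:
    if use_cb and task == "generate":
        return ("vllm_spyre.v1.core.scheduler."
                "ContinuousBatchingSpyreScheduler")
    if task == "generate":
        return "vllm_spyre.v1.core.scheduler.StaticBatchingSpyreScheduler"
    if task == "embed":
        return "vllm_spyre.core.scheduler.SpyreScheduler"
    raise ValueError("Only the 'generate' and 'embed' tasks are supported")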
@@ -137,12 +137,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         logger.info(
             "Overriding configurations based on warmup shapes. "
             "max_model_len=%d, max_num_seqs=%d, block_size=%d, "
-            "num_gpu_blocks_override=%d",
-            model_config.max_model_len,
-            scheduler_config.max_num_seqs,
-            cache_config.block_size,
-            cache_config.num_gpu_blocks_override,
-        )
+            "num_gpu_blocks_override=%d, max_num_batched_tokens=%d",
+            model_config.max_model_len, scheduler_config.max_num_seqs,
+            cache_config.block_size, cache_config.num_gpu_blocks_override,
+            scheduler_config.max_num_batched_tokens)
 
         # set env vars for torch_sendnn to consume
         os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(
@@ -227,7 +225,8 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
         """Returns whether the current platform can support v1 for the supplied
         model configuration.
         """
-        return True
+        # We don't have an embedding runner for v1 yet
+        return model_config.task != "embed"
 
     @classmethod
     def validate_request(
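The supports_v1 hunk is what signals the engine choice upstream: per its docstring, vLLM asks the platform whether v1 can support the supplied model configuration, and the platform now answers no for embedding models. A minimal restatement, assuming a plain task string in place of ModelConfig:

# Minimal restatement of the new supports_v1 behavior (illustrative only).
def supports_v1_for_task(task: str) -> bool:
    # We don't have an embedding runner for v1 yet
    return task != "embed"

assert supports_v1_for_task("generate") is True
assert supports_v1_for_task("embed") is False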
