Commit eb21f7a

Bump vllm to v0.10.1 and add compatibility code (#443)
Signed-off-by: Max de Bayser <[email protected]>
1 parent: 203eb21

File tree

4 files changed: +74 additions, -27 deletions


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ license = {text = "Apache 2"}
 dependencies = [
     "fms-model-optimizer[fp8]>=0.6.0",
     "ibm-fms>=1.2.1",
-    "vllm>=0.9.2,<=0.10.0",
+    "vllm>=0.9.2,<=0.10.1.1",
 ]
 requires-python = ">=3.9"
 dynamic = ["version"]
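For context, the effect of the new pin can be sanity-checked against an installed environment. A minimal sketch, assuming the packaging library is importable (pip vendors it, but it is not declared as a dependency above):

from importlib.metadata import version

from packaging.version import Version

# Mirror the constraint "vllm>=0.9.2,<=0.10.1.1" declared in pyproject.toml.
installed = Version(version("vllm"))
assert Version("0.9.2") <= installed <= Version("0.10.1.1"), (
    f"vllm {installed} is outside the supported range")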

tests/utils/test_upstream_compatibility.py

Lines changed: 17 additions & 0 deletions
@@ -168,3 +168,20 @@ def test_mm_inputs():
         "renamed mm_inputs to mm_kwargs.")
     # The compat code introduced in the PR below can now be removed:
     # https://github.com/vllm-project/vllm-spyre/pull/380
+
+
+@pytest.mark.cpu
+def test_init_builtin_logitsprocs():
+
+    import vllm.v1.sample.logits_processor
+    has_init_builtin_logitsprocs = hasattr(vllm.v1.sample.logits_processor,
+                                           "init_builtin_logitsprocs")
+
+    if VLLM_VERSION == "vLLM:main":
+        assert not has_init_builtin_logitsprocs
+    elif VLLM_VERSION == "vLLM:lowest":
+        assert has_init_builtin_logitsprocs, (
+            "The lowest supported vLLM version already "
+            "refactored init_builtin_logitsprocs.")
+    # The compat code introduced in the PR below can now be removed:
+    # https://github.com/vllm-project/vllm-spyre/pull/443
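The same probe can be run by hand to see which side of the compatibility split an installed vLLM falls on. A minimal sketch; which attribute any release outside the two versions tested above exposes is an assumption:

import vllm.v1.sample.logits_processor as lp

# Older releases (e.g. the lowest supported 0.9.2) still expose the legacy
# helper; newer ones replace it with the LogitsProcessors container.
print(hasattr(lp, "init_builtin_logitsprocs"))
print(hasattr(lp, "LogitsProcessors"))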

vllm_spyre/v1/worker/spyre_input_batch.py

Lines changed: 30 additions & 10 deletions
@@ -5,16 +5,17 @@

 from abc import abstractmethod
 from dataclasses import dataclass, field
-from typing import Generic, Optional, TypeVar, cast
+from typing import Any, Generic, Optional, TypeVar, cast

 import numpy as np
 import torch
+import vllm.v1.sample.logits_processor
+from vllm.config import VllmConfig
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
-                                             MoveDirectionality,
-                                             init_builtin_logitsprocs)
+                                             MoveDirectionality)
 from vllm.v1.sample.metadata import SamplingMetadata


@@ -200,6 +201,29 @@ def num_tokens(self) -> int:
         return len(self.prompt_token_ids) + len(self.output_token_ids)


+# Compatibility code, remove when no supported version
+# has init_builtin_logitsprocs any more
+def get_builtin_logits_processors(
+        vllm_config: Optional[VllmConfig] = None) -> Any:
+    if hasattr(vllm.v1.sample.logits_processor, "LogitsProcessors"):
+        if vllm_config is None:
+            return vllm.v1.sample.logits_processor.LogitsProcessors()
+        return vllm.v1.sample.logits_processor.LogitsProcessors(
+            ctor(vllm_config, "cpu", False)
+            for ctor in vllm.v1.sample.logits_processor.\
+                BUILTIN_LOGITS_PROCESSORS)
+    else:
+        if vllm_config is None:
+            return vllm.v1.sample.logits_processor.LogitsProcessorManager(
+                non_argmax_invariant=[],
+                argmax_invariant=[],
+            )
+        return vllm.v1.sample.logits_processor.init_builtin_logitsprocs(
+            pin_memory_available=False,
+            max_num_reqs=vllm_config.scheduler_config.max_num_seqs + 1,
+            device="cpu")
+
+
 class SamplingInputBatch(BaseInputBatch[SamplingRequestState]):
     '''
     This class was based on the InputBatch for GPU of vLLM V1.
@@ -229,6 +253,8 @@ def __init__(
         device: torch.device,
         pin_memory: bool,
         vocab_size: int,
+        # Type here is Any for compatibility reasons
+        logitsprocs: Optional[Any] = None,
     ):
         super().__init__(
             max_num_reqs,
@@ -297,13 +323,7 @@ def __init__(
         # updates. Should reset each step.
         self.batch_update_builder = BatchUpdateBuilder()

-        # Define logits processors.
-        # TODO(andy): logits processor list should be extensible via engine
-        # constructor argument; for now the list is fixed.
-        self.logitsprocs = init_builtin_logitsprocs(pin_memory_available=False,
-                                                    max_num_reqs=max_num_reqs +
-                                                    1,
-                                                    device=device)
+        self.logitsprocs = logitsprocs or get_builtin_logits_processors()

         self.has_allowed_token_ids: set[str] = set()
         self.allowed_token_ids_mask: Optional[torch.Tensor] = None
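The shim above reduces to hasattr-based dispatch between two module surfaces. A minimal self-contained sketch of that pattern, using hypothetical old_module/new_module stand-ins rather than real vLLM names:

from types import SimpleNamespace

# Hypothetical stand-ins for the legacy and refactored module surfaces.
old_module = SimpleNamespace(old_api=lambda: "legacy manager")
new_module = SimpleNamespace(new_api=lambda: "new container")

def get_processors(module):
    # Prefer the new entry point when present, just as
    # get_builtin_logits_processors probes for LogitsProcessors.
    if hasattr(module, "new_api"):
        return module.new_api()
    return module.old_api()

print(get_processors(new_module))  # new container
print(get_processors(old_module))  # legacy manager

Note that the diff imports the module itself (import vllm.v1.sample.logits_processor) and resolves attributes at call time, so one code path serves whichever API the installed vLLM provides.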

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 26 additions & 16 deletions
@@ -29,12 +29,9 @@
 from vllm_spyre.platform import SpyrePlatform
 # yapf conflicts with ruff for this block
 # yapf: disable
-from vllm_spyre.v1.worker.spyre_input_batch import (BaseInputBatch,
-                                                    BaseRequestState,
-                                                    PoolingInputBatch,
-                                                    PoolingRequestState,
-                                                    SamplingInputBatch,
-                                                    SamplingRequestState)
+from vllm_spyre.v1.worker.spyre_input_batch import (
+    BaseInputBatch, BaseRequestState, PoolingInputBatch, PoolingRequestState,
+    SamplingInputBatch, SamplingRequestState, get_builtin_logits_processors)

 # yapf: enable
 if TYPE_CHECKING:
@@ -306,12 +303,17 @@ def load_model(self, prompt_lens: Iterable[int],
         )

     def build_input_batch(self) -> SamplingInputBatch:
+        # Define logits processors.
+        # TODO(Max): logits processor list should be extensible via engine
+        # constructor argument; for now the list is fixed to builtin processors
+        logits_processors = get_builtin_logits_processors(self.vllm_config)
         return SamplingInputBatch(
             max_num_reqs=self.scheduler_config.max_num_seqs,
             max_model_len=self.model_config.max_model_len,
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
+            logitsprocs=logits_processors,
         )

     @property
@@ -810,8 +812,7 @@ def __init__(
             max_model_len=vllm_config.model_config.max_model_len,
             device=self.device,
             pin_memory=self.pin_memory,
-            vocab_size=vllm_config.model_config.get_vocab_size(),
-        )
+            vocab_size=vllm_config.model_config.get_vocab_size())

     def pre_warmup(self) -> None:
         # Set the number of kv cache blocks to the minimal value of 2 which is
@@ -1351,9 +1352,18 @@ def build_input_batch(self) -> PoolingInputBatch:
     def load_model(self, prompt_lens: Iterable[int],
                    num_decode_tokens: Iterable[int]) -> None:

-        if self.model_config.task == "embed":
+        task = self.model_config.task
+        if task is None:
+            # Task is being deprecated upstream because the models
+            # support several tasks at once. But for now, here we need
+            # to know the task to load the model with
+            # AutoModelForSequenceClassification
+            task = self.model_config._get_default_pooling_task(
+                self.model_config.architectures)
+
+        if task == "embed":
             self.model = AutoModel.from_pretrained(self.model_config.model)
-        elif self.model_config.task == "classify":
+        elif task == "classify":
             class_model = AutoModelForSequenceClassification.from_pretrained(
                 self.model_config.model)
             if hasattr(class_model, "bert"):
@@ -1368,7 +1378,7 @@ def load_model(self, prompt_lens: Iterable[int],
                 "Bert or Roberta for sequence classification")
             self.classifier = class_model.classifier
         else:
-            raise ValueError(f"Unsupported task {self.model_config.task}")
+            raise ValueError(f"Unsupported task {task}")

         model_class_name = type(self.model).__name__
         self.is_roberta = "roberta" in model_class_name.lower()
@@ -1393,7 +1403,7 @@ def load_model(self, prompt_lens: Iterable[int],
                               dynamic=False,
                               backend=envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND)

-        if self.model_config.task == "classify":
+        if task == "classify":
             tokenizer = AutoTokenizer.from_pretrained(self.model_config.model)
             output = tokenizer(text="foo", text_pair="bar")
             self.use_token_type_ids = "token_type_ids" in output
@@ -1404,13 +1414,13 @@ def load_model(self, prompt_lens: Iterable[int],
         if hasattr(Pooler, "from_config_with_defaults"):
             # TODO: remove this when we no longer support
             # vllm version v0.9.2
-            if self.model_config.task == "embed":
+            if task == "embed":
                 self.pooler = Pooler.from_config_with_defaults(
                     pooler_config,
                     pooling_type=PoolingType.CLS,
                     normalize=True,
                     softmax=False)
-            elif self.model_config.task == "classify":
+            elif task == "classify":
                 self.pooler = ClassifierPooler(config=self.model_config,
                                                pooler=self._pooler,
                                                classifier=self.classifier)
@@ -1428,10 +1438,10 @@ def load_model(self, prompt_lens: Iterable[int],
             if 'default_pooling_type' in annotations:
                 extra_args['default_pooling_type'] = PoolingType.CLS

-            if self.model_config.task == "embed":
+            if task == "embed":
                 self.pooler = Pooler.for_embed(pooler_config=pooler_config,
                                                **extra_args)
-            elif self.model_config.task == "classify":
+            elif task == "classify":
                 self.pooler = ClassifierPooler(
                     pooling=self._pooler,
                     classifier=self.classifier,

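The task-resolution fallback added to load_model can also be read on its own. A minimal sketch of the same logic; note that _get_default_pooling_task is a private method on vLLM's ModelConfig, so this deliberately leans on an internal API:

def resolve_pooling_task(model_config):
    # `task` is being deprecated upstream; when it is unset, fall back
    # to vLLM's default pooling task for the model's architectures.
    task = model_config.task
    if task is None:
        task = model_config._get_default_pooling_task(
            model_config.architectures)
    if task not in ("embed", "classify"):
        raise ValueError(f"Unsupported task {task}")
    return task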