
Commit 377895d

removing legacy backward compatibility (#313)
### removing legacy backward compatibility

Due to the requirement `vllm>=0.9.2`, we can safely revert some of the changes introduced by #245.

---------

Signed-off-by: Yannick Schnider <[email protected]>
1 parent e18ec15 commit 377895d
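
All three files drop the same feature-detection idiom: probe the installed vllm's request/output classes at runtime and pass only the keyword arguments that exist, so the code also runs against older releases. Below is a minimal before/after sketch of that idiom on a hypothetical stand-in dataclass (`Request` is not the real `EngineCoreRequest`/`NewRequestData`/`ModelRunnerOutput` API); with `vllm>=0.9.2` guaranteed, the direct form is all that is needed.

```python
# Sketch only: the compatibility idiom this commit removes, shown on a
# hypothetical stand-in class (not the real vllm API).
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Request:
    request_id: str
    # Fields that only exist in newer versions of the upstream library:
    data_parallel_rank: Optional[int] = None
    pooling_params: Optional[dict] = None


def build_compat(request_id: str) -> Request:
    """Old style: probe the class and pass only the fields that exist."""
    extra_kwargs: dict[str, Any] = {}
    if "data_parallel_rank" in Request.__dataclass_fields__:
        extra_kwargs["data_parallel_rank"] = None
    if "pooling_params" in Request.__dataclass_fields__:
        extra_kwargs["pooling_params"] = None
    return Request(request_id=request_id, **extra_kwargs)


def build_direct(request_id: str) -> Request:
    """New style: with vllm>=0.9.2 pinned, the fields always exist."""
    return Request(request_id=request_id,
                   data_parallel_rank=None,
                   pooling_params=None)
```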

File tree

3 files changed (+9, -29 lines)

tests/spyre_util.py

Lines changed: 3 additions & 9 deletions
@@ -543,13 +543,6 @@ def create_random_request(
         request_id: int, num_tokens: int,
         sampling_params: SamplingParams) -> EngineCoreRequest:
 
-    # Temporary until these parameters make it to a release version in vllm
-    extra_kwargs: dict[str, Any] = {}
-    if "data_parallel_rank" in EngineCoreRequest.__annotations__:
-        extra_kwargs["data_parallel_rank"] = None
-    if "pooling_params" in EngineCoreRequest.__annotations__:
-        extra_kwargs["pooling_params"] = None
-
     return EngineCoreRequest(request_id=str(request_id),
                              prompt_token_ids=[request_id] * num_tokens,
                              mm_inputs=None,
@@ -559,8 +552,9 @@ def create_random_request(
                              eos_token_id=None,
                              arrival_time=0,
                              lora_request=None,
-                             cache_salt=None,
-                             **extra_kwargs)
+                             data_parallel_rank=None,
+                             pooling_params=None,
+                             cache_salt=None)
 
 
 def skip_unsupported_tp_size(size: int, backend: str):
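
The guard removed here keyed off `EngineCoreRequest.__annotations__`, which maps each annotated attribute name to its type, so a simple membership test works as a version probe. A small illustration on a hypothetical class (not the vllm one):

```python
# Hypothetical class used only to illustrate the __annotations__ probe.
from typing import Optional


class Msg:
    request_id: str
    cache_salt: Optional[str] = None


print("cache_salt" in Msg.__annotations__)      # True  -> kwarg would be passed
print("pooling_params" in Msg.__annotations__)  # False -> kwarg would be skipped
```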

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 2 additions & 6 deletions
@@ -3,7 +3,7 @@
 from collections import deque
 from collections.abc import Iterable
 from dataclasses import asdict, dataclass
-from typing import TYPE_CHECKING, Any, Optional, cast
+from typing import TYPE_CHECKING, Optional, cast
 
 import torch
 from torch import nn
@@ -426,10 +426,6 @@ def execute_model(
                 req, str) else self.requests[req]
             req_state.output_token_ids.extend(sampled_ids[i])
 
-        extra_kwargs: dict[str, Any] = {}
-        if "pooler_output" in ModelRunnerOutput.__dataclass_fields__:
-            extra_kwargs["pooler_output"] = None
-
         prompt_logprobs_dicts = self._get_prompt_logprobs_dict(
             logits=logits, model_inputs=model_input)
 
@@ -445,7 +441,7 @@ def execute_model(
             logprobs=(output.logprobs_tensors.tolists()
                       if output.logprobs_tensors else None),
             prompt_logprobs_dict=prompt_logprobs_dicts,
-            **extra_kwargs,
+            pooler_output=None,
         )
 
         return model_output
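
For dataclasses such as `ModelRunnerOutput`, the equivalent probe is `__dataclass_fields__` (or `dataclasses.fields()`), which the removed block consulted before the output could be built with `pooler_output=None` unconditionally. A sketch on a stand-in dataclass, not the real vllm type:

```python
# Stand-in dataclass to show the __dataclass_fields__ probe; not vllm's
# ModelRunnerOutput.
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class Output:
    req_ids: list[str]
    pooler_output: Optional[list] = None


# Both spellings answer "does this version have the field?"
assert "pooler_output" in Output.__dataclass_fields__
assert any(f.name == "pooler_output" for f in fields(Output))
```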

vllm_spyre/v1/worker/spyre_worker.py

Lines changed: 4 additions & 14 deletions
@@ -5,7 +5,7 @@
 import platform
 import signal
 import time
-from typing import Any, Optional, Union, cast
+from typing import Optional, Union, cast
 
 import torch
 import torch.distributed as dist
@@ -319,11 +319,6 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
         warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
             0, len(valid_token_ids_tensor), (batch_size + 1, prompt_len))]
 
-        # TODO temporary until 'pooling_params' makes it to a release version
-        # in vllm
-        extra_kwargs: dict[str, Any] = {}
-        if "pooling_params" in NewRequestData.__dataclass_fields__:
-            extra_kwargs["pooling_params"] = None
         dummy_requests = [
             NewRequestData(
                 req_id="warmup-%d" % (i),
@@ -335,7 +330,8 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
                 block_ids=[0],  # not actually used
                 num_computed_tokens=0,
                 lora_request=None,
-                **extra_kwargs) for i in range(batch_size + 1)
+                pooling_params=None,
+            ) for i in range(batch_size + 1)
         ]
         add_dummy_request = dummy_requests.pop(-1)
 
@@ -473,12 +469,6 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
         warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
             0, len(valid_token_ids_tensor), (batch_size, prompt_len))]
 
-        # TODO temporary until 'pooling_params' makes it to a release version
-        # in vllm
-        extra_kwargs: dict[str, Any] = {}
-        if "pooling_params" in NewRequestData.__dataclass_fields__:
-            extra_kwargs["pooling_params"] = None
-
         # Set up dummy requests for prefill steps
         dummy_requests = [
             NewRequestData(
@@ -491,7 +481,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
                 block_ids=[0],
                 num_computed_tokens=0,
                 lora_request=None,
-                **extra_kwargs) for i in range(batch_size)
+                pooling_params=None) for i in range(batch_size)
         ]
 
         # Set up dummy cached_requests for decode steps
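
The premise of the whole revert is the `vllm>=0.9.2` floor referenced in the commit message, which the project enforces through its dependency constraints. If an explicit runtime guard were ever wanted, a sketch (not part of this commit) could look like:

```python
# Illustrative runtime check for the version floor this commit relies on;
# the project itself enforces it via its dependency requirement, not code.
from importlib.metadata import version

from packaging.version import Version

if Version(version("vllm")) < Version("0.9.2"):
    raise RuntimeError("vllm-spyre requires vllm>=0.9.2")
```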
