Merged
Changes from 6 commits
2 changes: 1 addition & 1 deletion .github/workflows/test-spyre.yml
@@ -25,5 +25,5 @@ jobs:
export DISTRIBUTED_STRATEGY_IGNORE_MODULES=WordEmbedding && \
cd vllm-spyre && \
python -m pytest --timeout=300 tests -v -k "V0 and eager" && \
python -m pytest --forked --timeout=300 tests -v -k "V1 and eager"
python -m pytest --forked --timeout=300 tests -v -k "V1- and eager"
'''
8 changes: 4 additions & 4 deletions Dockerfile.spyre
@@ -18,13 +18,13 @@ RUN ln -sf $(which python${PYTHON_VERSION}) /usr/bin/python && \
# Download and install vllm ###########################################################
RUN git clone --depth 1 https://github.com/vllm-project/vllm.git \
&& cd vllm \
&& git fetch origin pull/14242/head:spyre-workarounds \
&& git checkout spyre-workarounds \
&& git fetch --tags \
&& git checkout v0.8.0 \
&& python -m pip install --upgrade pip \
&& pip3 install torch=="2.5.1+cpu" --index-url https://download.pytorch.org/whl/cpu \
&& python use_existing_torch.py \
&& pip install -r requirements-build.txt \
&& SETUPTOOLS_SCM_PRETEND_VERSION=0.7.3 VLLM_TARGET_DEVICE=empty pip install --verbose . --no-build-isolation
&& pip install -r requirements/build.txt \
&& SETUPTOOLS_SCM_PRETEND_VERSION=0.8.0 VLLM_TARGET_DEVICE=empty pip install --verbose . --no-build-isolation

# Install vllm Spyre plugin ##################################################################
RUN mkdir /workspace/vllm-spyre
160 changes: 114 additions & 46 deletions vllm_spyre/v1/core/scheduler.py
@@ -1,14 +1,15 @@
# SPDX-License-Identifier: Apache-2.0

from collections import deque
from typing import Deque, Optional
from typing import Deque

from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.core.scheduler_output import SchedulerOutput
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs, FinishReason
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus

logger = init_logger(__name__)
@@ -20,22 +21,9 @@ class SpyreScheduler(Scheduler):
- Only schedules batches of requests that fit a common warmup shape
"""

def __init__(
self,
scheduler_config: SchedulerConfig,
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
log_stats: bool,
) -> None:
def __init__(self, *args, **kwargs) -> None:
# Initialize vLLM scheduler
super().__init__(scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
lora_config=lora_config,
speculative_config=speculative_config,
log_stats=log_stats)
super().__init__(*args, **kwargs)

# Add our own state for handling Spyre constraints

@@ -49,6 +37,51 @@ def __init__(
# scheduler sees have at least one common warmup shape.
self.holdback_queue: Deque[Request] = deque()

self.rejected_requests: set[str] = set()

def add_request(self, request: Request) -> None:
"""This override rejects requests that fit no warmup shape"""
if len(
self._get_matching_warmup_shapes(request=request,
warmup_shapes=list(
self.spyre_warmup_shapes),
current_batch_size=0)) == 0:
logger.warning(
"No applicable warmup shape exists for "
"combination of prompt length (%d tokens) "
"and maximum number of output tokens to be "
"generated (%d tokens) from request id %s",
request.num_prompt_tokens, request.sampling_params.max_tokens,
request.request_id)
# TODO: There are open PRs that should enable raising an error for
# a single request like this, which will gracefully return an error
# for the request, instead of shutting down the engine.
# See https://github.com/vllm-project/vllm/pull/11737
# raise ValueError("Request does not fit any spyre warmup shape")

# For now, we'll insert a dummy request and manually reject it when
# we construct the outputs later
self.rejected_requests.add(request.request_id)
request.prompt_token_ids = [0]
request.num_prompt_tokens = 1
request.sampling_params = SamplingParams(max_tokens=1)

# delegate to super
super().add_request(request=request)

def update_from_output(
self,
scheduler_output: SchedulerOutput,
model_runner_output: ModelRunnerOutput,
) -> EngineCoreOutputs:
"""Temporary override to handle rejected requests that were too large
to schedule."""
reject_outputs = self._handle_rejects()
outputs = super().update_from_output(scheduler_output,
model_runner_output)
outputs.outputs.extend(reject_outputs)
return outputs

def schedule(self) -> "SchedulerOutput":
"""This override adds constraints and then delegates most of the work
to the base scheduler"""
@@ -71,40 +104,19 @@ def schedule(self) -> "SchedulerOutput":

# prune the possible shapes to only those that fit this request
# and the growing batch size
max_tokens = 0
if request.sampling_params is not None and\
request.sampling_params.max_tokens is not None:
max_tokens = request.sampling_params.max_tokens

available_warmup_shapes = [
shape for shape in available_warmup_shapes
if request.num_prompt_tokens <= shape['prompt_length']
and max_tokens <= shape['new_tokens']
and len(self.waiting) < shape['batch_size']
]
available_warmup_shapes = self._get_matching_warmup_shapes(
request=request,
warmup_shapes=available_warmup_shapes,
current_batch_size=len(self.waiting))

if len(available_warmup_shapes) > 0:
# There is still at least one valid shape, so add to the
# waiting queue
self.waiting.append(self.holdback_queue.popleft())
else:
# We can't schedule this one.
# If it's the first request, then it fits _no_ shapes at all
# So we reject it entirely
if len(self.waiting) == 0:
logger.warning(
"No applicable warmup shape exists for "
"combination of prompt length (%d tokens) "
"and maximum number of output tokens to be "
"generated (%d tokens)", request.num_prompt_tokens,
request.sampling_params.max_tokens)

request.status = RequestStatus.FINISHED_IGNORED
self._free_request(self.holdback_queue.popleft())
else:
# Otherwise, we simply stop here so that the scheduler
# can work with the batch we have
break
# Otherwise, we simply stop here so that the scheduler
# can work with the batch we have
break

logger.debug(
"Scheduling a new batch of %d requests, holding back %d "
@@ -119,3 +131,59 @@ def schedule(self) -> "SchedulerOutput":
def get_num_unfinished_requests(self) -> int:
# Override this to include our extra queue
return len(self.waiting) + len(self.running) + len(self.holdback_queue)

def _get_matching_warmup_shapes(
self, request: Request, warmup_shapes: list[dict[str, int]],
current_batch_size: int) -> list[dict[str, int]]:
"""Return the subset of shapes that match this request"""
max_tokens = 0
if request.sampling_params is not None and\
request.sampling_params.max_tokens is not None:
max_tokens = request.sampling_params.max_tokens

return [
shape for shape in warmup_shapes
if request.num_prompt_tokens <= shape['prompt_length']
and max_tokens <= shape['new_tokens']
and current_batch_size < shape['batch_size']
]

def _handle_rejects(self) -> list[EngineCoreOutput]:
"""Temporary solution to reject requests that were too large to
schedule. This removes the rejected requests from the scheduler, and
returns empty outputs for them with finish reason `abort`.
Comment on lines +152 to +154
Member:

The behaviour in V0 was that if the request didn't match any warmup shape, it would be added to the ignore queue. This resulted in finish reason 'length' and an empty string as output. While we didn't really like this behaviour, it was consistent with how vLLM handled similar cases (e.g., prompt length being longer than max model len) in V0. Has this changed in V1?

Collaborator Author:

@tdoublep Yeah, as far as I can tell the v1 scheduler doesn't have the ability to ignore requests. The new flow in the v1 engine is that scheduling is broken down into two passes, schedule() and update_from_output(). The engine expects every request to eventually schedule at least one token via schedule(), and if no tokens are scheduled for an iteration it'll just go idle waiting for more inputs. Once the model has been invoked, update_from_output() is used to construct the outputs for the engine's callers. If we look at the simple case where we get only a single request and it doesn't fit a warmup shape, we can't just not schedule it, or the engine will go idle and we never get a chance to pass a result back to the caller. So we have to schedule at least one dummy token and then pass back an empty result from update_from_output(), which is a bit sad.
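
Roughly, the two passes fit together like this. This is a simplified sketch with made-up names, not the actual vLLM engine code:

# Simplified sketch of the v1 engine core loop (hypothetical names, not vLLM internals).
# If schedule() emits no tokens, the loop idles and the caller never receives an output,
# which is why a rejected request still has to schedule one dummy token.
def engine_step(scheduler, executor):
    scheduler_output = scheduler.schedule()  # pass 1: pick requests/tokens to run
    if scheduler_output.total_num_scheduled_tokens == 0:
        return []  # nothing scheduled, so the engine just waits for more inputs
    model_output = executor.execute_model(scheduler_output)
    # pass 2: turn the model output into per-request engine outputs
    return scheduler.update_from_output(scheduler_output, model_output)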

As far as I can tell, all of the request validation has been moved ahead of the scheduler, so that for cases like len(prompt) > max_model_len, the user gets an error and the scheduler is never invoked for that request. For online requests, that's handled in the api server for both V0 and V1:

{
  "object": "error",
  "message": "This model's maximum context length is 100 tokens. However, you requested 153 tokens (137 in the messages, 16 in the completion). Please reduce the length of the messages or completion.",
  "type": "BadRequestError",
  "param": null,
  "code": 400
}

But the offline entrypoint doesn't do as much request validation; with V1, the engine handles this during "input processing" and will now raise a ValueError on any invalid inputs:

model = LLM("/models/llama-194m", max_model_len=100)
model.generate("hello 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 hello 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 hello 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 hello 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/senuser/my-vllm/lib64/python3.11/site-packages/vllm/utils.py", line 1080, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/senuser/my-vllm/lib64/python3.11/site-packages/vllm/entrypoints/llm.py", line 462, in generate
    self._validate_and_add_requests(
  File "/home/senuser/my-vllm/lib64/python3.11/site-packages/vllm/entrypoints/llm.py", line 1310, in _validate_and_add_requests
    self._add_request(
  File "/home/senuser/my-vllm/lib64/python3.11/site-packages/vllm/entrypoints/llm.py", line 1328, in _add_request
    self.llm_engine.add_request(
  File "/home/senuser/my-vllm/lib64/python3.11/site-packages/vllm/v1/engine/llm_engine.py", line 164, in add_request
    request = self.processor.process_inputs(request_id, prompt, params,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/senuser/my-vllm/lib64/python3.11/site-packages/vllm/v1/engine/processor.py", line 141, in process_inputs
    self._validate_model_inputs(processed_inputs)
  File "/home/senuser/my-vllm/lib64/python3.11/site-packages/vllm/v1/engine/processor.py", line 266, in _validate_model_inputs
    raise ValueError(
ValueError: Prompt length of 137 is longer than the maximum model length of 100.

IIUC, the input processing step was removed from the main engine loop in V1 so that it could better handle these cases. Previously, with V0, any errors would kill the engine completely; now V1 can raise any validation errors it wants before the request gets to the main engine loop. But I do see with Robert's incoming changes here that there should soon be per-request failure handling from inside the engine loop, which may allow us to raise a ValueError from our scheduler; that would be a lot better than this hack.

However, I think the best approach would be to let plugins hook into the v1 engine's input processing step so that we can properly enforce the new assumption that every request that makes it to the scheduler is schedulable. I can start talking to maintainers about that, but I still wanted to get this temporary fix in so we can keep moving forward.
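
For illustration only, such a hook could look roughly like the sketch below. The hook itself doesn't exist yet and all names are hypothetical, but the shape check mirrors _get_matching_warmup_shapes from this PR:

def validate_request_against_warmup_shapes(prompt_token_ids, sampling_params,
                                           warmup_shapes):
    # Hypothetical pre-scheduler validation: reject the request during input
    # processing, before it ever reaches the engine loop, the same way V1
    # rejects prompts that exceed max_model_len.
    max_tokens = sampling_params.max_tokens or 0
    if not any(
            len(prompt_token_ids) <= shape['prompt_length']
            and max_tokens <= shape['new_tokens'] for shape in warmup_shapes):
        raise ValueError("Request does not fit any Spyre warmup shape")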

"""
if len(self.rejected_requests) == 0:
return []

# Remove rejected requests from all queues
reject_outputs = self._reject_from_queue(self.running)
reject_outputs.extend(self._reject_from_queue(self.waiting))
reject_outputs.extend(self._reject_from_queue(self.holdback_queue))
self.rejected_requests.clear()

return reject_outputs

def _reject_from_queue(self,
queue: Deque[Request]) -> list[EngineCoreOutput]:
"""Remove rejected requests from a given queue and return a list of
engine core outputs to return for them"""
reject_outputs: list[EngineCoreOutput] = []
rejected_requests: list[Request] = [
request for request in queue
if request.request_id in self.rejected_requests
]

for request in rejected_requests:
queue.remove(request)
reject_outputs.append(
EngineCoreOutput(request.request_id,
new_token_ids=[],
finish_reason=FinishReason.ABORT,
stop_reason="Request did not fit any warmup "
"shape"))
request.status = RequestStatus.FINISHED_ABORTED
Comment on lines +179 to +185
Collaborator:

@joerunde Would your solution still work when you put:

            reject_outputs.append(
                EngineCoreOutput(request.request_id,
                                 new_token_ids=[],
                                 finish_reason=FinishReason.LENGTH,
                                 stop_reason="Request did not fit any warmup "
                                 "shape"))
            request.status = RequestStatus.FINISHED_IGNORED

here? As @tdoublep mentioned, this would be more consistent with how we handled it previously.
A question regarding your comment: does this mean that for both online and offline inference the RequestStatus and FinishReason are never set, and an error is raised earlier instead? Are FinishReason.LENGTH and RequestStatus.FINISHED_IGNORED therefore obsolete here? If that is the case, I don't have any preference...

Collaborator Author:
Yeah, as far as I can tell the RequestStatus.FINISHED_IGNORED: FinishReason.LENGTH entry in that map for v1 is unused. It looks like the new behavior is to raise an error when a prompt is too long. I guess I don't really care either way what the finish_reason is here, so I can change it to length. I'd like to get this in, and then moving forward we can figure out how best to match v1's behavior of raising an error. I'll check on the vLLM Slack, though, to make sure that's the intended behavior.

Collaborator:
Sounds good. Also, no strong opinion here; I just asked for better understanding.

self._free_request(request)
self.rejected_requests.remove(request.request_id)

return reject_outputs
29 changes: 21 additions & 8 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -305,6 +305,12 @@ def prepare_model_input(
# TODO: Build the rest of the SamplingMetadata correctly
dummy_tensors = lambda v: torch.full(
(num_reqs, ), v, device=self.device)

# vllm 0.7.3 backwards compatibility
extra_kwargs: dict = {}
if "bad_words_token_ids" in SamplingMetadata.__dataclass_fields__:
extra_kwargs["bad_words_token_ids"] = {}

dummy_metadata = SamplingMetadata(
temperature=dummy_tensors(0.0),
all_greedy=False,
Expand All @@ -323,7 +329,7 @@ def prepare_model_input(
min_tokens={},
logit_bias=[None for _ in range(num_reqs)],
allowed_token_ids_mask=None,
)
**extra_kwargs)

return ModelInputForSpyre(input_tokens=input_tokens,
input_positions=input_positions,
@@ -499,10 +505,17 @@ class in the modeling code. Every attention layer populates an entry
"""
# We do at least use the real size from the cache config.
block_size = self.vllm_config.cache_config.block_size
return {
"foo":
FullAttentionSpec(block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float16)
}

# vllm 0.7.3 backwards compatibility
try:
attn_spec = FullAttentionSpec(block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float16)
except TypeError:
attn_spec = FullAttentionSpec(block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float16,
use_mla=False)
return {"foo": attn_spec}