4 changes: 2 additions & 2 deletions tests/spyre_util.py
@@ -333,7 +333,7 @@ def create_random_request(
        mm_placeholders=None,
        sampling_params=sampling_params,
        eos_token_id=None,
-        arrival_time=0,
+        arrival_time=time.monotonic(),
        lora_request=None,
        data_parallel_rank=None,
        pooling_params=None,
@@ -350,7 +350,7 @@ def create_random_request(
        multi_modal_placeholders=None,
        sampling_params=sampling_params,
        eos_token_id=None,
-        arrival_time=0,
+        arrival_time=time.monotonic(),
        lora_request=None,
        pooling_params=None,
        cache_salt=None,
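The tests now stamp requests with a real monotonic arrival time because the scheduler change below measures waiting time as `time.monotonic() - request.arrival_time`; with the old `arrival_time=0`, every test request would appear to have been waiting since the monotonic clock's origin and would immediately trip the new threshold. A minimal sketch of that measurement (`FakeRequest` is an illustrative stand-in, not part of the PR):

```python
import time

class FakeRequest:
    """Illustrative stand-in for vLLM's request object."""

    def __init__(self) -> None:
        # Monotonic clock: immune to wall-clock adjustments (NTP, DST).
        self.arrival_time = time.monotonic()

request = FakeRequest()
time.sleep(0.1)  # simulate the request sitting in the waiting queue
waiting_time = time.monotonic() - request.arrival_time
print(f"waited ~{waiting_time:.2f}s")  # ~0.10, independent of wall-clock time
```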
6 changes: 6 additions & 0 deletions vllm_spyre/envs.py
@@ -114,6 +114,12 @@ def _backend_backwards_compat() -> str:
"VLLM_SPYRE_N_TOKENS_PREFILL_PRIO":
lambda: int(os.getenv("VLLM_SPYRE_N_TOKENS_PREFILL_PRIO", "-1")),

# scheduling heuristic: maximal waiting (blocking) time for prefill
# Prefills waiting longer than VLLM_SPYRE_MAX_WAITING_TIME_SECONDS
# seconds will have priority after the current decode batch has finished.
"VLLM_SPYRE_MAX_WAITING_TIME_SECONDS":
lambda: float(os.getenv("VLLM_SPYRE_MAX_WAITING_TIME_SECONDS", "inf")),

# Allow vllm-spyre to update env vars related to multi-threading (eg. OMP)
# based on the detected CPU cores and server configuration
"VLLM_SPYRE_UPDATE_THREAD_CONFIG":
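Because the default is the string "inf", an unset variable parses to `float("inf")`, and no finite waiting time can ever exceed it, so the heuristic is effectively disabled unless the operator opts in. A quick standalone sketch of that parsing behaviour (outside the PR's env registry):

```python
import os

def max_waiting_time_seconds() -> float:
    # Unset -> float("inf"); no prefill can wait "longer than infinity",
    # so the batch is never locked.
    return float(os.getenv("VLLM_SPYRE_MAX_WAITING_TIME_SECONDS", "inf"))

print(max_waiting_time_seconds())  # inf when unset
os.environ["VLLM_SPYRE_MAX_WAITING_TIME_SECONDS"] = "30"
print(max_waiting_time_seconds())  # 30.0
print(1e9 > float("inf"))          # False: the heuristic stays dormant
```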
14 changes: 14 additions & 0 deletions vllm_spyre/platform.py
@@ -255,6 +255,20 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
            envs_spyre.VLLM_SPYRE_N_TOKENS_PREFILL_PRIO,
            envs_spyre.VLLM_SPYRE_N_TOKENS_PREFILL_PRIO)

+        # scheduling heuristic: maximal waiting (blocking) time for prefill
+        if math.isinf(envs_spyre.VLLM_SPYRE_MAX_WAITING_TIME_SECONDS):
+            logger.info(
+                "Env var VLLM_SPYRE_MAX_WAITING_TIME_SECONDS, which sets the "
+                "maximal waiting time for a request, is unset. Defaulting to "
+                "inf, i.e. unbounded waiting time (heuristic disabled).")
+        else:
+            logger.info(
+                "Env var VLLM_SPYRE_MAX_WAITING_TIME_SECONDS, which sets the "
+                "maximal waiting time for a request, is set to %ss: prefills "
+                "waiting longer than %s seconds will always be prioritized.",
+                envs_spyre.VLLM_SPYRE_MAX_WAITING_TIME_SECONDS,
+                envs_spyre.VLLM_SPYRE_MAX_WAITING_TIME_SECONDS)

    @classmethod
    def use_all_gather(cls) -> bool:
        """
27 changes: 27 additions & 0 deletions vllm_spyre/v1/core/scheduler.py
@@ -2,6 +2,7 @@

import math
import os
+import time
from collections import deque
from typing import TYPE_CHECKING

@@ -159,6 +160,9 @@ def __init__(self, *args, **kwargs) -> None:
        # cache for self.check_batch_tkv_limit() outer key: tuple(request_ids),
        # inner key: (request_id, max_batch_tkv_limit), value: (lower, upper)
        self._cache_check_batch_tkv_limit: dict[tuple, dict[tuple, tuple]] = {}
+        # if batch_is_locked: finish the current decode batch before serving a
+        # request that waited longer than VLLM_SPYRE_MAX_WAITING_TIME_SECONDS
+        self.batch_is_locked = False

    def update_from_output(
        self,

@@ -182,6 +186,11 @@ def schedule(self) -> "SchedulerOutput":
        To avoid additional specialization, some requests are held back from
        the base scheduler but are restored after.
        """
+        # unlock the current decode batch if no requests are in the running
+        # queue
+        if len(self.running) == 0 and self.batch_is_locked:
+            self.batch_is_locked = False
+            logger.debug("Unlocking the current decode batch as no requests "
+                         "are in the running queue")
        # First purge the full waiting queue into our holdback queue,
        # preserving priority
        while self.waiting:
@@ -224,11 +233,29 @@ def can_schedule(self, request) -> bool:
        max_prompt_batch_size = 1
        max_context_len = self.scheduler_config.max_model_len

+        # if the batch is locked by a request which has been waiting for
+        # longer than VLLM_SPYRE_MAX_WAITING_TIME_SECONDS, we cannot
+        # schedule the current sequence until we have served that request
+        if self.batch_is_locked:
+            return False
Collaborator:
instead of locking the batch entirely, shouldn't we just disallow any skipping of requests in the queue until the request at the head of the waiting queue schedules?

I haven't followed super closely, but my assumption is that the blocked request may be able to be scheduled before the full batch finishes. E.g. with the 128k limit, a 64k request could potentially schedule once the batch has drained down to a single other request, so we wouldn't need to wait for the last one to finish.

Collaborator (Author):
great idea! I will certainly address that in a follow-up. We wanted to keep the first version as simple and fail-proof as possible.

+        # running and waiting queues are both empty -> start a new batch,
+        # which can always be scheduled
+        if len(self.running) + len(self.waiting) == 0:
+            return True
+
+        # scheduling heuristic: maximal waiting (blocking) time for prefill
+        waiting_time = (time.monotonic() - request.arrival_time)
+        if waiting_time > envs_spyre.VLLM_SPYRE_MAX_WAITING_TIME_SECONDS:
+            self.batch_is_locked = True
+            logger.debug(
+                "Request %s waited longer (%ss) than "
+                "VLLM_SPYRE_MAX_WAITING_TIME_SECONDS (%ss): locking the "
+                "current decode batch and scheduling this request either as "
+                "part of the current batch or in an exclusive subsequent new "
+                "batch.", request.request_id, round(waiting_time, 2),
+                envs_spyre.VLLM_SPYRE_MAX_WAITING_TIME_SECONDS)

        # check that there is space in the current decode batch
        cond1 = len(self.running) + len(
            self.waiting) < self.max_num_running_reqs
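Taken together, the heuristic reduces to three rules: an overdue prefill locks the batch; a locked batch admits no new requests; the lock clears once the running queue drains. A condensed, self-contained sketch of that control flow (`ToyScheduler` is a simplified stand-in, not the real scheduler API, and the remaining capacity checks such as `cond1` are elided):

```python
import time

MAX_WAIT = 30.0  # stands in for VLLM_SPYRE_MAX_WAITING_TIME_SECONDS

class ToyScheduler:
    """Condensed control flow of the max-waiting-time heuristic."""

    def __init__(self) -> None:
        self.running: list[float] = []  # arrival times of running requests
        self.waiting: list[float] = []  # arrival times of waiting requests
        self.batch_is_locked = False

    def on_schedule_step(self) -> None:
        # Unlock once the locked decode batch has fully drained.
        if not self.running and self.batch_is_locked:
            self.batch_is_locked = False

    def can_schedule(self, arrival_time: float) -> bool:
        # A locked batch admits nothing new until it finishes.
        if self.batch_is_locked:
            return False
        # An empty system can always start a fresh batch.
        if not self.running and not self.waiting:
            return True
        # An overdue prefill locks the batch; it is then served within the
        # current batch (if the elided capacity checks allow) or in an
        # exclusive new batch right after the drain.
        if time.monotonic() - arrival_time > MAX_WAIT:
            self.batch_is_locked = True
        return True  # capacity checks elided

sched = ToyScheduler()
print(sched.can_schedule(time.monotonic()))  # True: empty system
```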