Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions tests/spyre_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import math
import os
import random
import time
from pathlib import Path
from typing import Any, Optional, Union

Expand Down Expand Up @@ -605,7 +606,7 @@ def create_random_request(
mm_placeholders=None,
sampling_params=sampling_params,
eos_token_id=None,
arrival_time=0,
arrival_time=time.time(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest using time.monotonic() instead to avoid issues with daylight savings etc.

lora_request=None,
data_parallel_rank=None,
pooling_params=None,
Expand All @@ -622,7 +623,7 @@ def create_random_request(
multi_modal_placeholders=None,
sampling_params=sampling_params,
eos_token_id=None,
arrival_time=0,
arrival_time=time.time(),
lora_request=None,
pooling_params=None,
cache_salt=None,
Expand Down
6 changes: 6 additions & 0 deletions vllm_spyre/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ def _backend_backwards_compat() -> str:
lambda: bool(int(os.getenv("VLLM_SPYRE_ENABLE_PREFILL_OPTIMIZATION", "0"))
),

# scheduling heuristic: maximal waiting (blocking) time for prefill
# Prefills waiting longer than VLLM_SPYRE_MAX_WAITING_TIME_PREFILL
# seconds will have priority after the current decode batch has finished.
"VLLM_SPYRE_MAX_WAITING_TIME_PREFILL":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name should reflect the units of time that are being used (e.g., VLLM_SPYRE_MAX_WAITING_TIME_SECONDS) or something. Should we also consider using an integer instead of a float?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that int vs float has already been considered - please ignore that part.

lambda: int(os.getenv("VLLM_SPYRE_MAX_WAITING_TIME_PREFILL", "-1")),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this also be a float so that the user can specify 0.5 for 500ms?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point, I just changed that


# Allow vllm-spyre to update env vars related to multi-threading (eg. OMP)
# based on the detected CPU cores and server configuration
"VLLM_SPYRE_UPDATE_THREAD_CONFIG":
Expand Down
14 changes: 14 additions & 0 deletions vllm_spyre/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,20 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
"VLLM_DT_MAX_BATCH_TKV_LIMIT found. Using the default value " \
"(max_model_len * max_batch_size): %d", default_max_batch_tkv_limit)

# scheduling heuristic: maximal waiting (blocking) time for prefill
if envs_spyre.VLLM_SPYRE_MAX_WAITING_TIME_PREFILL == -1:
logger.info(
"Env var VLLM_SPYRE_MAX_WAITING_TIME_PREFILL determining the "
"maximal waiting time for a request unset. Defaulting to -1, "
"which is infinite time (no scheduler heuristic at all).")
else:
logger.info(
"Env var VLLM_SPYRE_MAX_WAITING_TIME_PREFILL determining the "
"maximal waiting time is set to %s. This means that prefills "
"waiting longer than %s seconds will always be prioritized. ",
envs_spyre.VLLM_SPYRE_MAX_WAITING_TIME_PREFILL,
envs_spyre.VLLM_SPYRE_MAX_WAITING_TIME_PREFILL)

@classmethod
def use_all_gather(cls) -> bool:
"""
Expand Down
24 changes: 23 additions & 1 deletion vllm_spyre/v1/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import math
import os
import time
from collections import deque
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -156,6 +157,9 @@ def __init__(self, *args, **kwargs) -> None:
assert self.max_batch_tkv_limit != '-1', (
"Expecting the env var VLLM_DT_MAX_BATCH_TKV_LIMIT to be set in "
"platform.py")
# if batch_is_locked: finish current decode batch to serve a request
# that waited for longer than VLLM_SPYRE_MAX_WAITING_TIME_PREFILL
self.batch_is_locked = False

def update_from_output(
self,
Expand All @@ -179,14 +183,20 @@ def schedule(self) -> "SchedulerOutput":
To avoid additional specialization, some requests are held back from the
base scheduler but are restored after.
"""
# unlock the current decode batch if no requests are in running queue
if len(self.running) == 0:
self.batch_is_locked = False
logger.debug("Unlocking the current decode batch as no requests "
"are in running queue")
# First purge the full waiting queue into our holdback queue, preserving
# priority
while self.waiting:
self.holdback_queue.append(self.waiting.popleft())

# Check if new requests can be scheduled.
while self.holdback_queue:
if self.can_schedule(self.holdback_queue[0]):
if not self.batch_is_locked and self.can_schedule(
self.holdback_queue[0]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be tested directly in the can_schedule() function? Maybe it can be the first tested condition, and return False directly if wrong

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it could also go on top of can_schedule(), true. Having it here is less code and avoids jumping into can_schedule() if we already know it is gonna return False. The way I interpret can_schedule(req) is a check whether request req could be scheduled with the current decode batch. The flag batch_is_locked was set by yet another request (not by req nor by any request in self.running). So this case can be treated outside of can_schedule(). But my opinion is not very strong here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you decide, my thought what that the decision of scheduling or not was entirely in one place. but I see your point also

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having it here is less code and avoids jumping into can_schedule() if we already know it is gonna return False

Couldn't it just be the first thing we check in can_schedule ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved it

# Add request to the waiting queue
self.waiting.append(self.holdback_queue.popleft())
else:
Expand Down Expand Up @@ -226,6 +236,18 @@ def can_schedule(self, request) -> bool:
if len(self.running) + len(self.waiting) == 0:
return True

# scheduling heuristic: maximal waiting (blocking) time for prefill
if envs_spyre.VLLM_SPYRE_MAX_WAITING_TIME_PREFILL > 0:
waiting_time = (time.time() - request.arrival_time)
if waiting_time > envs_spyre.VLLM_SPYRE_MAX_WAITING_TIME_PREFILL:
self.batch_is_locked = True
logger.debug("Request %s waited longer (%ds) than " \
"VLLM_SPYRE_MAX_WAITING_TIME_PREFILL (%ds): locking current" \
" decode batch and schedule this request afterwards.",
request.request_id, waiting_time,
envs_spyre.VLLM_SPYRE_MAX_WAITING_TIME_PREFILL
)

# check that there is space in the current decode batch
cond1 = len(self.running) + len(
self.waiting) < self.max_num_running_reqs
Expand Down