82 changes: 82 additions & 0 deletions examples/offline_inference_spyre_cb_test.py
@@ -0,0 +1,82 @@
import os
import time

from vllm import LLM, SamplingParams

max_tokens1 = 10
max_tokens2 = 5
max_tokens3 = 7
max_tokens = max([max_tokens1, max_tokens2, max_tokens3])
max_num_seqs = 2

os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = '64'
os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max_tokens)
os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = '4'

# defined here so the script can be run/debugged directly from VS Code
# (not only via the terminal)
os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = 'eager'
os.environ['VLLM_SPYRE_USE_CB'] = '1'
os.environ['VLLM_USE_V1'] = '1'

# Sample prompts.
template = (
    "Below is an instruction that describes a task. Write a response that "
    "appropriately completes the request. Be polite in your response to the "
    "user.\n\n### Instruction:\n{}\n\n### Response:")

prompt1 = template.format(
    "Provide a list of instructions for preparing chicken soup for a family "
    "of four.")

prompt2 = template.format("Provide instructions for preparing chicken soup.")

prompt3 = template.format(
    "Provide a list of instructions for preparing chicken soup for a family.")

prompts = [
    prompt1,
    prompt2,
    prompt3,
]

# Create a sampling params object.
sampling_params1 = SamplingParams(max_tokens=max_tokens1,
                                  temperature=0.0,
                                  ignore_eos=True)

sampling_params2 = SamplingParams(max_tokens=max_tokens2,
                                  temperature=0.0,
                                  ignore_eos=True)

sampling_params3 = SamplingParams(max_tokens=max_tokens3,
                                  temperature=0.0,
                                  ignore_eos=True)

sampling_params = [
    sampling_params1,
    sampling_params2,
    sampling_params3,
]

# Create an LLM.
llm = LLM(model="/models/llama-194m",
          tokenizer="/models/llama-194m",
          max_model_len=2048,
          block_size=2048,
          max_num_seqs=max_num_seqs)

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
print("=============== GENERATE")
t0 = time.time()
outputs = llm.generate(prompts, sampling_params)
print("Time elapsed for %d tokens is %.2f sec" %
      (len(outputs[0].outputs[0].token_ids), time.time() - t0))
print("===============")
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print("===============")
for output in outputs:
    print(output.outputs[0])
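
Note on the example above: llm.generate pairs each prompt with the sampling params at the same index, so the three sequences finish after 10, 5, and 7 generated tokens and leave the batch at different steps, which is what exercises the continuous-batching path. A minimal illustrative sketch of that pairing (not part of the example file; it just reuses the prompts and sampling_params lists defined above):

# Illustrative sketch only: prompt i is generated with sampling_params[i].
assert len(prompts) == len(sampling_params)
for prompt, params in zip(prompts, sampling_params):
    # 10, 5 and 7 tokens respectively, so sequences finish at different steps
    print(f"max_tokens={params.max_tokens} for prompt {prompt[:40]!r}")
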
25 changes: 13 additions & 12 deletions vllm_spyre/model_executor/model_loader/spyre.py
@@ -68,21 +68,20 @@ def __init__(
max_decode_length,
)

# horizontal offset in physical KV cache memory block
self.tkv: int = 0

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
masks: torch.Tensor,
is_prompt: bool,
tkv: Optional[int] = None,
active_pages: Optional[list[int]] = None,
) -> torch.Tensor:

if is_prompt:
self.tkv = 0
if not envs_spyre.VLLM_SPYRE_USE_CB:
self.model.past_key_value_states = None
self.tkv = tkv
Reviewer comment: Why do we set self.tkv here? It looks like it is not used.

if is_prompt and not envs_spyre.VLLM_SPYRE_USE_CB:
self.model.past_key_value_states = None

extra_kwargs = {}
if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn_decoder":
@@ -128,7 +127,7 @@ def forward(
self.model.sample_mask = matrix.unsqueeze(0)

# prefill of batch size 1
logits, self.tkv = self.model(
logits = self.model(
self.model.sample_token_id,
position_ids=self.model.sample_position,
mask=self.model.sample_mask,
@@ -153,16 +152,18 @@ def forward(
masks[0, :, :] = self.model.sample_mask

# normal prefill or decoding step
logits, self.tkv = self.model(
logits = self.model(
input_ids,
position_ids=positions,
mask=masks,
use_cache=True,
only_last_token=True,
tkv=self.tkv,
active_pages=[i for i in range(input_ids.shape[0])],
#active_pages=[i for i in range(input_ids.shape[0])],
Reviewer comment: nit: remove commented out lines

active_pages=active_pages,
**extra_kwargs,
)

if TESTING_CB and self.tkv >= (6 + 64):
# update sample_token_id, sample_position and sample_mask
self.model.update_sample_inputs(logits=logits[0, :])
@@ -430,7 +431,7 @@ def forward(
page, :, :tkv, :] = key_value_states[layer][1][
idx, :, :, :] # [1, 8, L, 128]

return logits, tkv + 1
return logits

def update_sample_inputs(
self,
@@ -494,4 +495,4 @@ def forward(
for tensor in layer:
torch._dynamo.mark_dynamic(tensor, 2)

return logits, tkv + 1
return logits
8 changes: 6 additions & 2 deletions vllm_spyre/platform.py
@@ -57,8 +57,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if envs.VLLM_USE_V1:
# As of 0.7.3 the scheduler for V1 isn't actually pluggable like
# this yet
scheduler_config.scheduler_cls = \
"vllm_spyre.v1.core.scheduler.SpyreScheduler"
if envs_spyre.VLLM_SPYRE_USE_CB:
scheduler_config.scheduler_cls = \
"vllm_spyre.v1.core.scheduler.ContinuousBatchingSpyreScheduler"
else:
scheduler_config.scheduler_cls = \
"vllm_spyre.v1.core.scheduler.SpyreScheduler"
else:
scheduler_config.scheduler_cls = \
"vllm_spyre.core.scheduler.SpyreScheduler"
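
For reference, the new scheduler is only selected when both the V1 engine and the Spyre continuous-batching flag are enabled, mirroring the environment variables set in the example script. A hedged sketch of the toggle as seen from user code, assuming the flags are read before engine construction:

import os

# Both flags below must be set before the engine is built for
# check_and_update_config() to pick the continuous-batching scheduler.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_SPYRE_USE_CB"] = "1"
# scheduler_config.scheduler_cls then becomes
# "vllm_spyre.v1.core.scheduler.ContinuousBatchingSpyreScheduler".
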
100 changes: 97 additions & 3 deletions vllm_spyre/v1/core/scheduler.py
@@ -9,6 +9,8 @@
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus

import vllm_spyre.envs as envs_spyre

try:
from vllm.v1.core.sched.scheduler import Scheduler
except ImportError:
@@ -75,7 +77,7 @@ def add_request(self, request: Request) -> None:
request.sampling_params = SamplingParams(max_tokens=1)

# delegate to super
super().add_request(request=request)
self.run_add_request(request=request)

def update_from_output(
self,
@@ -93,7 +95,6 @@ def update_from_output(
def schedule(self) -> SchedulerOutput:
"""This override adds constraints and then delegates most of the work
to the base scheduler"""

# First purge the full waiting queue into our holdback queue, preserving
# priority
while self.waiting:
@@ -133,7 +134,7 @@ def schedule(self) -> SchedulerOutput:
logger.debug("Scheduling a running batch of %d requests",
len(self.running))

outputs = super().schedule()
outputs = self.run_schedule()
return outputs

def get_num_unfinished_requests(self) -> int:
@@ -195,3 +196,96 @@ def _reject_from_queue(self,
self.rejected_requests.remove(request.request_id)

return reject_outputs

    def run_schedule(self) -> SchedulerOutput:
        return super().schedule()

    def run_add_request(self, request: Request) -> None:
        super().add_request(request=request)


class ContinuousBatchingSpyreScheduler(SpyreScheduler):
    """Scheduler with support for continuous batching"""

    def __init__(self, *args, **kwargs) -> None:
        # Initialize vLLM scheduler
        super().__init__(*args, **kwargs)
        # running queue of last decoding step
        self.last_running: list[Request] = []
        self.total_running: list[Request] = []

    def add_request(self, request: Request) -> None:
        """This override rejects requests that exceed the max context length"""
        if not request.num_prompt_tokens + request.sampling_params.max_tokens \
                <= envs_spyre.VLLM_SPYRE_MAX_CONTEXT_LENGTH:
            logger.warning(
                "Could not add request id %s: prompt length is %d tokens and "
                "the maximum number of output tokens is %d",
                request.request_id,
                request.num_prompt_tokens,
                request.sampling_params.max_tokens,
            )
            logger.warning("Could not schedule request id %s",
                           request.request_id)
            # TODO: There are open PRs that should enable raising an error for
            # a single request like this, which will gracefully return an error
            # for the request, instead of shutting down the engine.
            # See https://github.com/vllm-project/vllm/pull/11737
            # raise ValueError("Request does not fit any spyre warmup shape")

            # For now, we'll insert a dummy request and manually reject it when
            # we construct the outputs later
            self.rejected_requests.add(request.request_id)
            request.prompt_token_ids = [0]
            request.num_prompt_tokens = 1
            request.sampling_params = SamplingParams(max_tokens=1)

        # delegate to super
        self.run_add_request(request=request)

    def schedule(self) -> "SchedulerOutput":
        """This override adds constraints and then delegates most of the work
        to the base scheduler"""
        # First purge the full waiting queue into our holdback queue,
        # preserving priority
        while self.waiting:
            self.holdback_queue.append(self.waiting.popleft())

        # Check whether new requests can be scheduled.
        self.total_running = self.last_running + self.running
        while self.holdback_queue:
            request = self.holdback_queue[0]
            if self.can_schedule(request=request):
                # Add request to the waiting queue
                self.waiting.append(self.holdback_queue.popleft())
            else:
                # Otherwise, we simply stop here so that the scheduler
                # can work with the batch we have
                break

        if len(self.waiting) > 0:
            # If a prefill is scheduled, save the running queue for the next
            # decode step. If the previous step was also a prefill, the
            # running queue contains the previous prefill sequence.
            self.last_running = self.total_running
            self.running = []
            logger.debug(
                "Scheduling a prompt step of %d requests, holding back %d "
                "requests", len(self.waiting), len(self.holdback_queue))
        else:
            # If a decode is scheduled and the previous step was a prefill,
            # update the running queue
            self.running = self.total_running
            self.last_running = []
            logger.debug("Scheduling a decode step of %d requests",
                         len(self.running))

        outputs = self.run_schedule()
        return outputs

    def can_schedule(self, request: Request) -> bool:
        max_prompt_batch_size = 1
        # TODO: add additional checks, e.g. max_tokens
        return len(self.total_running) + len(self.waiting) < \
            self.max_num_running_reqs and \
            len(self.waiting) < max_prompt_batch_size
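
Taken together, can_schedule() admits at most one new prefill per scheduling step (max_prompt_batch_size = 1) and caps the number of live sequences at max_num_running_reqs. A standalone sketch of the same predicate, with toy values chosen here for illustration (max_num_running_reqs=2 matches the max_num_seqs used in the example script):

# Toy restatement of the constraint enforced by can_schedule() above.
def can_schedule_sketch(num_total_running: int, num_waiting: int,
                        max_num_running_reqs: int = 2,
                        max_prompt_batch_size: int = 1) -> bool:
    return (num_total_running + num_waiting < max_num_running_reqs
            and num_waiting < max_prompt_batch_size)

print(can_schedule_sketch(0, 0))  # True: empty batch, first prefill admitted
print(can_schedule_sketch(0, 1))  # False: only one prefill per scheduling step
print(can_schedule_sketch(1, 0))  # True: one sequence decoding, one slot left
print(can_schedule_sketch(2, 0))  # False: batch already at max_num_running_reqs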