4 changes: 3 additions & 1 deletion vllm_spyre/platform.py
@@ -200,7 +200,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# set env vars for torch_sendnn to consume
os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(
vllm_config.model_config.max_model_len)
# min decode batch size is 2 due to symbolic shape constraint in torch
# min value 2 needed for VLLM_DT_MAX_BATCH_SIZE (compiler constraint)
# Note that we can still have decodes of batch size 1 as the env var
# only concerns the max batch size.
os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(
max(vllm_config.scheduler_config.max_num_seqs, 2))
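
For reference, a minimal sketch of how these two environment variables are derived from the engine configuration. The standalone helper and the example values are illustrative; only the variable names and the max(..., 2) clamp come from the code above.

```python
import os

def set_spyre_env_vars(max_model_len: int, max_num_seqs: int) -> None:
    # Context length is passed straight through to torch_sendnn.
    os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(max_model_len)
    # The compiler requires VLLM_DT_MAX_BATCH_SIZE >= 2, but individual
    # decode steps may still run with a single sequence.
    os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(max_num_seqs, 2))

# e.g. max_num_seqs=1 still yields VLLM_DT_MAX_BATCH_SIZE="2"
set_spyre_env_vars(max_model_len=2048, max_num_seqs=1)
```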

26 changes: 1 addition & 25 deletions vllm_spyre/v1/worker/spyre_model_runner.py
@@ -291,13 +291,8 @@ def load_model(self, prompt_lens: Iterable[int],
)

def build_input_batch(self) -> SamplingInputBatch:
# Fix for batch size 1: set input batch to fit 2 requests for warmup,
# and reset input batch to fit max_num_seqs requests after warmup
min_seqs_required = 2 if self.warmup_mode else 1

return SamplingInputBatch(
max_num_reqs=max(min_seqs_required,
self.scheduler_config.max_num_seqs),
max_num_reqs=self.scheduler_config.max_num_seqs,
max_model_len=self.model_config.max_model_len,
device=self.device,
pin_memory=self.pin_memory,
@@ -810,8 +805,6 @@ def __init__(

def complete_warmup(self) -> None:
super().complete_warmup()
# Fix for batch size 1: need to update the input_batch after the warmup
self.input_batch = self.build_input_batch()
# get the number of pages from the actual Spyre card after the warmup
# and set it accordingly in the model runner and the kv cache size
n_blocks_avail = self._get_num_blocks_available()
@@ -1112,23 +1105,6 @@ def _prepare_decode(
# mask not needed during decode
mask = None

# add pads for min decode batch size of 2 (Spyre compiler constraint)
if len(cached_request_data.req_ids) == 1:
padd_seq_indices = torch.zeros(1, dtype=torch.bool, device="cpu")
self.model.indices = torch.cat(
(self.model.indices, padd_seq_indices), -1)
assert self.model.indices.size(dim=0) == 2

input_tokens = torch.cat(2 * [input_tokens])
position_ids = torch.cat(2 * [position_ids])
current_tkv_mask = torch.cat(2 * [current_tkv_mask])
left_padded_prompt_mask = torch.cat(2 * [left_padded_prompt_mask])
block_table = torch.cat(2 * [block_table])
slot_mapping = torch.cat(2 * [slot_mapping])

# assert min batch size 2 for decodes (Spyre compiler constraint)
assert len(input_tokens) >= 2

model_inputs = SamplingForwardInputs(
input_tokens=input_tokens,
input_positions=position_ids,
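The block removed above was the workaround for the old minimum decode batch size of 2: with a single running sequence, every decode-side tensor was duplicated before the forward pass. A small illustrative sketch of that duplication (the tensor names and shapes here are made up, not the runner's actual fields) shows what batch-size-1 decodes no longer have to pay for:

```python
import torch

def pad_decode_batch_to_two(input_tokens: torch.Tensor,
                            position_ids: torch.Tensor):
    # Old workaround: duplicate a batch of one so the compiled graph
    # always saw a decode batch of at least 2. No longer applied now
    # that size-1 batch dimensions stay dynamic.
    if input_tokens.size(0) == 1:
        input_tokens = torch.cat(2 * [input_tokens])
        position_ids = torch.cat(2 * [position_ids])
    return input_tokens, position_ids

tokens = torch.tensor([[17]])     # one sequence, one new token
positions = torch.tensor([[41]])  # its position
padded_tokens, _ = pad_decode_batch_to_two(tokens, positions)
assert padded_tokens.size(0) == 2  # doubled work for a single request
```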
87 changes: 37 additions & 50 deletions vllm_spyre/v1/worker/spyre_worker.py
@@ -330,6 +330,10 @@ def load_model(self):
logger.info("load model took %.3fs", load_model_total_t)

def _warmup_spyre_dynamic_size(self, special_token_ids):
# this setting is required to mark a dimension of size 1 as dynamic
# for pytorch >= 2.7.1 (needed to support batch size 1 for decodes)
from torch.fx.experimental import _config as config
config.backed_size_oblivious = True

warmup_start_t = time.time()
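
The `backed_size_oblivious` flag is the enabling piece of this change: it tells PyTorch's symbolic-shape machinery not to specialize dimensions that happen to be 0 or 1, so a batch dimension traced at size 1 can stay dynamic. A self-contained toy sketch of the effect (the tiny model and shapes are illustrative, not Spyre code):

```python
import torch
from torch.fx.experimental import _config as fx_config

# Treat backed symbolic sizes of 0/1 like any other dynamic size instead
# of specializing on them (available in recent PyTorch, >= 2.7.1 here).
fx_config.backed_size_oblivious = True

def step(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1

compiled = torch.compile(step, dynamic=True)

# Warm up at batch size 1, then reuse the same dynamic graph at batch 4.
print(compiled(torch.ones(1, 8)).shape)  # torch.Size([1, 8])
print(compiled(torch.ones(4, 8)).shape)  # torch.Size([4, 8])
```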

@@ -347,15 +351,14 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
valid_token_ids_tensor = torch.tensor(valid_token_ids,
dtype=torch.long,
device=torch.device("cpu"))
batch_size = 2
prompt_len = 42
num_decode_tokens = 2

# Sample from the valid token ids
warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (batch_size + 1, prompt_len))]
0, len(valid_token_ids_tensor), (2, prompt_len))]

dummy_requests: list[NewRequestData] = [
warmup_req, deploy_req = (
NewRequestData(
req_id="warmup-%d" % (i),
prompt_token_ids=warmup_tokens_tensor[i].tolist(),
@@ -367,24 +370,21 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
block_ids=[0], # not actually used
num_computed_tokens=0,
lora_request=None,
) for i in range(batch_size + 1)
]
add_dummy_request = dummy_requests.pop(-1)
) for i in range(2))

with _maybe_warmup_context():
self._dynamic_warmup(dummy_requests=dummy_requests,
self._dynamic_warmup(request=warmup_req,
prompt_len=prompt_len,
batch_size=batch_size,
valid_token_ids_tensor=valid_token_ids_tensor)

# warmup_mode completes the graph compilation, but we need to do
# one additional prefill to deploy the compiled program to the device,
# the necessary operations are included in the graph and will be removed
# after this execution
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[add_dummy_request],
scheduled_new_reqs=[deploy_req],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={add_dummy_request.req_id: prompt_len},
num_scheduled_tokens={deploy_req.req_id: prompt_len},
total_num_scheduled_tokens=prompt_len,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
@@ -396,7 +396,7 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
)
logger.info("[WARMUP] Deploying to device...")
self.execute_model(scheduler_output)
self._cleanup_model_runner(request=[add_dummy_request])
self._cleanup_model_runner(request=[deploy_req])

model_runner.complete_warmup()

@@ -550,61 +550,48 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,

def _dynamic_warmup(
self,
dummy_requests: list[NewRequestData],
request: NewRequestData,
prompt_len: int,
batch_size: int,
valid_token_ids_tensor: torch.Tensor,
) -> None:

assert (
_inside_warmup_mode
), "it looks like you are outside the warmup context for warmup"

for i, req in enumerate(dummy_requests):
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[req],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req.req_id: prompt_len},
total_num_scheduled_tokens=prompt_len,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
logger.info("[WARMUP] Prefill %d/%d...", i + 1, batch_size)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[request],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={request.req_id: prompt_len},
total_num_scheduled_tokens=prompt_len,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
logger.info("[WARMUP] Prefill...")

self.execute_model(scheduler_output)
self.execute_model(scheduler_output)

# one decode iteration across all sequences
req_ids = []
new_token_ids = []
new_block_ids = []
num_computed_tokens = []
for req in dummy_requests:
req_ids.append(req.req_id)
new_token_ids.append([
valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (1, )).item()]
]) # placeholder token
new_block_ids.append([req.block_ids])
num_computed_tokens.append(prompt_len)
cached_request_data = CachedRequestData(
req_ids=req_ids,
req_ids=[request.req_id],
resumed_from_preemption=False,
new_token_ids=new_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
new_token_ids=[[
valid_token_ids_tensor[torch.randint(
0, len(valid_token_ids_tensor), (1, )).item()]
]],
new_block_ids=[request.block_ids],
num_computed_tokens=[prompt_len],
)

scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=cached_request_data,
num_scheduled_tokens={f"warmup-{i}": 1
for i in range(batch_size)},
total_num_scheduled_tokens=batch_size,
num_scheduled_tokens={request.req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
@@ -615,7 +602,7 @@ def _dynamic_warmup(
)
logger.info("[WARMUP] Decode...")
self.execute_model(scheduler_output)
self._cleanup_model_runner(request=dummy_requests)
self._cleanup_model_runner(request=[request])

def _warmup_model_forward_pass(
self,
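With the padding and the warmup-time batch resizing gone, a single in-flight request now exercises the batch-size-1 decode path directly. A hedged end-to-end check using vLLM's public API (the model name is a placeholder, and this assumes a Spyre-enabled vLLM build with this plugin installed):

```python
from vllm import LLM, SamplingParams

# Placeholder model; any model supported by the Spyre backend works.
llm = LLM(model="ibm-granite/granite-3.0-2b-instruct", max_num_seqs=1)

# A single prompt keeps the scheduler at batch size 1 for decodes,
# the case that previously had to be padded to a batch of 2.
params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["Hello, Spyre!"], params)
print(outputs[0].outputs[0].text)
```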