Merged. Changes from 3 commits.
vllm_spyre/v1/worker/spyre_model_runner.py (174 changes: 85 additions, 89 deletions)
@@ -797,100 +797,94 @@ def _prepare_prompt(
        self,
        new_requests: list[NewRequestData],
    ) -> ModelForwardInputs:
-        assert len(new_requests) > 0
-        input_token_list: list[torch.Tensor] = []
+        # currently all prefills are of batch size 1
+        assert len(new_requests) == 1

-        # ceil division to pad to next block boundary
-        new_batch = len(self.req_ids2blocks) == 0
-        max_prompt_len = max([len(r.prompt_token_ids) for r in new_requests])
-        if not new_batch:
-            assert max_prompt_len <= self.tkv
-        n = max_prompt_len if new_batch else self.tkv
-        block_padding = math.ceil(n / self.block_size) * self.block_size
-        if new_batch:
-            self.tkv = block_padding
+        request = new_requests[0]
+        req_id = request.req_id
+        prompt_token_ids = request.prompt_token_ids
+        sampling_params = request.sampling_params
+        is_new_batch = len(self.req_ids2blocks) == 0
+        prompt_len = len(prompt_token_ids)

-        # Internal state is managed here.
-        slot_mapping = []
+        # make sure that the prompt length is at most the current tkv
+        # if it joins an existing decode batch
+        if not is_new_batch:
+            assert prompt_len <= self.tkv

        self.prefill_batch.clear_requests()

-        for request_data in new_requests:
-            # Reserve the max blocks used to serve current sequence
-            new_tokens = (request_data.sampling_params.max_tokens
-                          if request_data.sampling_params is not None else 0)
-            n = self.tkv + new_tokens - 1
-            n_reserved_blocks = math.ceil(n / self.block_size)
-            self.req_ids2reserved_blocks[
-                request_data.req_id] = n_reserved_blocks
-
-            # retrieve initial (unpadded) tokens
-            prompt_tokens = request_data.prompt_token_ids
-            left_padding = self.tkv - len(prompt_tokens)
-            input_token_list.append(
-                torch.tensor(prompt_tokens,
-                             dtype=torch.long,
-                             device=torch.device("cpu")))
-
-            # filling block table and slot mapping
-            block_table_i = []
-            slot_mapping_i = []
-            for pos_i in range(block_padding):
-                if pos_i % self.block_size == 0:
-                    block_number = self.block_pool.popleft()
-                    block_table_i.append(block_number)
-                block_offset = pos_i % self.block_size
-                slot = block_number * self.block_size + block_offset
-                slot_mapping_i.append(slot)
-            self.req_ids2blocks[request_data.req_id] = deque(block_table_i)
-            slot_mapping.append(slot_mapping_i)
-
-            # Add new requests to the cached states.
-            req_id = request_data.req_id
-            sampling_params = request_data.sampling_params
-            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-                generator = torch.Generator(device=self.device)
-                generator.manual_seed(sampling_params.seed)
-            else:
-                generator = None
+        # right padding to the next block boundary (ceil division)
+        # -> prefills must be multiples of the block size (Spyre constraint)
+        n = prompt_len if is_new_batch else self.tkv
+        block_padding = math.ceil(n / self.block_size) * self.block_size

-            req_state = CachedRequestState(
-                req_id=req_id,
-                prompt_token_ids=request_data.prompt_token_ids,
-                sampling_params=sampling_params,
-                generator=generator,
-                output_token_ids=[],
-                left_padding=left_padding)
-            self.requests[req_id] = req_state
-            self.input_batch.add_request(req_state)
-            self.prefill_batch.add_request(req_state)
+        # set the tkv to the block padding if starting a new decode batch
+        self.tkv = block_padding if is_new_batch else self.tkv
+
+        # left padding to align the prefill sequence with the tkv of the
+        # current decode batch (Spyre constraint)
+        left_padding = self.tkv - prompt_len

+        # Reserve the maximal number of blocks used to serve current sequence
Review comment, @sducouedic (Collaborator), Jul 25, 2025:
maybe? I was confused by "maximal" and "current"
Suggested change:
-        # Reserve the maximal number of blocks used to serve current sequence
+        # Reserve the number of blocks required to serve this new sequence

+        new_tokens = (sampling_params.max_tokens
+                      if sampling_params is not None else 0)
+        n = self.tkv + new_tokens - 1
+        n_reserved_blocks = math.ceil(n / self.block_size)
+        self.req_ids2reserved_blocks[req_id] = n_reserved_blocks
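As a quick standalone check of the reservation arithmetic above (values are hypothetical, not taken from this PR):

```python
import math

block_size = 64     # example Spyre KV-cache block size
tkv = 128           # right-aligned position of the current decode batch
max_tokens = 65     # sampling_params.max_tokens of the new request

# worst case: the sequence ends at position tkv + max_tokens - 1
n = tkv + max_tokens - 1                        # 192
n_reserved_blocks = math.ceil(n / block_size)   # ceil(192 / 64) = 3
```

So this request reserves 3 blocks up front even though its padded prompt currently fits in 2.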

+        # filling block table and slot mapping
+        blocks = []
+        slots = []
+        for pos_i in range(block_padding):
+            if pos_i % self.block_size == 0:
+                block_number = self.block_pool.popleft()
+                blocks.append(block_number)
+            block_offset = pos_i % self.block_size
+            slot = block_number * self.block_size + block_offset
+            slots.append(slot)
+        self.req_ids2blocks[req_id] = deque(blocks)

+        # Add new request to the cached states.
+        if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
+            generator = torch.Generator(device=self.device)
+            generator.manual_seed(sampling_params.seed)
+        else:
+            generator = None

+        req_state = CachedRequestState(req_id=req_id,
+                                       prompt_token_ids=prompt_token_ids,
+                                       sampling_params=sampling_params,
+                                       generator=generator,
+                                       output_token_ids=[],
+                                       left_padding=left_padding)
+        self.requests[req_id] = req_state
+        self.input_batch.add_request(req_state)
+        self.prefill_batch.add_request(req_state)

        # Refresh sampling metadata after all requests are added to the batch
        self.input_batch.refresh_metadata()
        self.prefill_batch.refresh_metadata()

-        # TODO: Review this in the future
-        # prefills are always of batch size 1 for this milestone
-        # Also, we added an input batch just for that.
-        actual_batch_size = len(input_token_list)
-        assert actual_batch_size == 1
-        self.model.indices = torch.ones(actual_batch_size,
-                                        dtype=torch.bool,
-                                        device='cpu')
-        # construct tensor from list
-        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int64)
-        block_table = None
+        self.model.indices = torch.ones(1, dtype=torch.bool, device='cpu')
+        slot_mapping = torch.tensor([slots], dtype=torch.int64)
+        prompt_token_ids_tensor = torch.tensor(prompt_token_ids,
+                                               dtype=torch.long,
+                                               device=torch.device("cpu"))

        # get position ids and attention mask
-        input_tokens, position_ids, mask =\
-            self.pad_input_ids(input_token_list, min_pad_length=block_padding)
+        # applies left padding to align with tkv of current decode batch
+        # and right padding to align with the next block boundary
+        input_tokens, position_ids, mask = self.pad_input_ids(
+            [prompt_token_ids_tensor], min_pad_length=block_padding)
        mask = mask.unsqueeze(1)

        # not needed for prefill
        current_tkv_mask = None
+        # left padding info is stored in CachedRequestState of self.requests
        left_padded_prompt_mask = None
+        # block table is stored in self.req_ids2blocks (only passed for decode)
+        block_table = None

        model_inputs = ModelForwardInputs(
            input_tokens=input_tokens,
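To see the new single-request prefill path end to end, here is a minimal standalone sketch of the padding and slot-mapping technique used above, with made-up values (block_pool, tkv, and block_size stand in for the runner's real state, not its actual API):

```python
import math
from collections import deque

block_size = 64
block_pool = deque(range(100))   # free KV-cache block ids (hypothetical)
tkv = 128                        # right-aligned length of the running decode batch
prompt_len = 70                  # new prompt joining that batch

# right-pad to the next block boundary; left-pad up to tkv
block_padding = math.ceil(tkv / block_size) * block_size   # 128
left_padding = tkv - prompt_len                            # 58

# one KV-cache slot per padded position, allocating a block every 64 slots
blocks, slots = [], []
for pos in range(block_padding):
    if pos % block_size == 0:
        blocks.append(block_pool.popleft())
    slots.append(blocks[-1] * block_size + pos % block_size)

assert len(blocks) == 2 and slots[:3] == [0, 1, 2]
```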
@@ -918,9 +912,6 @@ def _prepare_decode(
        block_table = []
        slot_mapping = []
        left_padded_prompt_mask = []
-        self.model.indices = torch.ones(len(cached_request_data.req_ids),
-                                        dtype=torch.bool,
-                                        device="cpu")

        assert len(self.input_batch.req_id_to_index) == len(
            cached_request_data.req_ids)
@@ -937,7 +928,7 @@
            # or jump decoding?

            req_state: CachedRequestState = self.requests[req_id]
-            # adding new blocks if needed
+            # adding new blocks for current sequence if needed
            if self.tkv // self.block_size + 1 > len(
                    self.req_ids2blocks[req_id]):
                self.req_ids2blocks[req_id].append(self.block_pool.popleft())
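The growth check above is easy to sanity-check with concrete numbers (hypothetical, for illustration only):

```python
block_size = 64
tkv = 128            # positions 0..127 already occupied before this step
n_blocks_held = 2    # what len(self.req_ids2blocks[req_id]) would return

# the token decoded this step lands at position tkv, which lives in
# block index tkv // block_size, so a third block is needed here
needs_new_block = tkv // block_size + 1 > n_blocks_held
assert needs_new_block   # 128 // 64 + 1 == 3 > 2
```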
@@ -948,37 +939,42 @@
            offset = self.tkv % self.block_size
            slot = [start_slot + offset]
            slot_mapping.append(slot)
-            output_token_ids = req_state.output_token_ids
-            generation_token = output_token_ids[-1]
+
+            # input token and position of the token generated in the last step
+            generation_token = req_state.output_token_ids[-1]
            input_tokens.append([generation_token])
            seq_len = cached_request_data.num_computed_tokens[
                cached_reqs_map[req_id]]
            input_positions.append([seq_len])

+            # retrieve left padding information stored during prefill and
+            # updated when calling reduce_left_padding()
            left_padded_prompt_mask.append(req_state.left_padding)

-        # update tkv
-        self.tkv = self.tkv + 1
-
+        # construct tensors from lists
        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        position_ids = torch.tensor(input_positions,
                                    dtype=torch.long,
                                    device=self.device)
-        current_tkv_mask = torch.tensor([self.tkv] * len(input_tokens),
-                                        dtype=torch.int64)
        left_padded_prompt_mask = torch.tensor(left_padded_prompt_mask,
                                               dtype=torch.long,
                                               device=self.device)
-        # construct tensors from lists
-        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int64)
        block_table = torch.tensor(block_table, dtype=torch.int64)
+        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int64)
+        self.model.indices = torch.ones(len(cached_request_data.req_ids),
+                                        dtype=torch.bool,
+                                        device="cpu")

-        # not needed for decode
+        # mask not needed during decode
        mask = None

+        # update tkv
+        self.tkv = self.tkv + 1

+        current_tkv_mask = torch.tensor([self.tkv] * len(input_tokens),
+                                        dtype=torch.int64)

        # add pads for min decode batch size of 2 (Spyre compiler constraint)
        if len(cached_request_data.req_ids) == 1:
            padd_seq_indices = torch.zeros(1, dtype=torch.bool, device="cpu")
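The hunk is truncated here, but the padding idea it starts can be sketched in isolation: pad a lone sequence up to the compiler's minimum decode batch of 2 and use a boolean mask to mark the real row. Shapes and names below are illustrative, not the runner's actual code:

```python
import torch

# one real decode sequence, one token per step (hypothetical values)
input_tokens = torch.tensor([[42]], dtype=torch.long)             # [1, 1]

# boolean row mask: True for the real sequence, False for the pad row
# (self.model.indices plays this role in the runner)
indices = torch.cat([torch.ones(1, dtype=torch.bool),
                     torch.zeros(1, dtype=torch.bool)])           # [True, False]

# append a dummy row so the compiled decode graph sees a batch of 2
input_tokens = torch.cat(
    [input_tokens, torch.zeros(1, 1, dtype=torch.long)])          # [2, 1]

assert input_tokens.shape == (2, 1) and indices.tolist() == [True, False]
```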