From 095cccdc9f0ead4ce568a5f504bbc794829224f2 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Fri, 25 Jul 2025 12:49:57 +0000 Subject: [PATCH 1/4] refactor _prepare_prompt Signed-off-by: Yannick Schnider --- vllm_spyre/v1/worker/spyre_model_runner.py | 142 ++++++++++----------- 1 file changed, 68 insertions(+), 74 deletions(-) diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index a6122ec9c..80212bf2f 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -797,100 +797,94 @@ def _prepare_prompt( self, new_requests: list[NewRequestData], ) -> ModelForwardInputs: - assert len(new_requests) > 0 - input_token_list: list[torch.Tensor] = [] + # currently all prefills are of batch size 1 + assert len(new_requests) == 1 - # ceil division to pad to next block boundary - new_batch = len(self.req_ids2blocks) == 0 - max_prompt_len = max([len(r.prompt_token_ids) for r in new_requests]) - if not new_batch: - assert max_prompt_len <= self.tkv - n = max_prompt_len if new_batch else self.tkv - block_padding = math.ceil(n / self.block_size) * self.block_size - if new_batch: - self.tkv = block_padding + request = new_requests[0] + req_id = request.req_id + prompt_token_ids = request.prompt_token_ids + sampling_params = request.sampling_params + is_new_batch = len(self.req_ids2blocks) == 0 + prompt_len = len(prompt_token_ids) - # Internal state is managed here. - slot_mapping = [] + # make sure that the prompt length is at most the current tkv + # if it joins an existing decode batch + if not is_new_batch: + assert prompt_len <= self.tkv self.prefill_batch.clear_requests() - for request_data in new_requests: - # Reserve the max blocks used to serve current sequence - new_tokens = (request_data.sampling_params.max_tokens - if request_data.sampling_params is not None else 0) - n = self.tkv + new_tokens - 1 - n_reserved_blocks = math.ceil(n / self.block_size) - self.req_ids2reserved_blocks[ - request_data.req_id] = n_reserved_blocks - - # retrieve initial (unpadded) tokens - prompt_tokens = request_data.prompt_token_ids - left_padding = self.tkv - len(prompt_tokens) - input_token_list.append( - torch.tensor(prompt_tokens, - dtype=torch.long, - device=torch.device("cpu"))) - - # filling block table and slot mapping - block_table_i = [] - slot_mapping_i = [] - for pos_i in range(block_padding): - if pos_i % self.block_size == 0: - block_number = self.block_pool.popleft() - block_table_i.append(block_number) - block_offset = pos_i % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping_i.append(slot) - self.req_ids2blocks[request_data.req_id] = deque(block_table_i) - slot_mapping.append(slot_mapping_i) - - # Add new requests to the cached states. 
-            req_id = request_data.req_id
-            sampling_params = request_data.sampling_params
-            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-                generator = torch.Generator(device=self.device)
-                generator.manual_seed(sampling_params.seed)
-            else:
-                generator = None
+        # right padding to the next block boundary (ceil division)
+        # -> prefills must be multiples of the block size (Spyre constraint)
+        n = prompt_len if is_new_batch else self.tkv
+        block_padding = math.ceil(n / self.block_size) * self.block_size
 
-            req_state = CachedRequestState(
-                req_id=req_id,
-                prompt_token_ids=request_data.prompt_token_ids,
-                sampling_params=sampling_params,
-                generator=generator,
-                output_token_ids=[],
-                left_padding=left_padding)
-            self.requests[req_id] = req_state
-            self.input_batch.add_request(req_state)
-            self.prefill_batch.add_request(req_state)
+        # set the tkv to the block padding if starting a new decode batch
+        self.tkv = block_padding if is_new_batch else self.tkv
+
+        # left padding to align the prefill sequence with the tkv of the
+        # current decode batch (Spyre constraint)
+        left_padding = self.tkv - prompt_len
+
+        # Reserve the maximal number of blocks used to serve current sequence
+        new_tokens = (sampling_params.max_tokens
+                      if sampling_params is not None else 0)
+        n = self.tkv + new_tokens - 1
+        n_reserved_blocks = math.ceil(n / self.block_size)
+        self.req_ids2reserved_blocks[req_id] = n_reserved_blocks
+
+        # filling block table and slot mapping
+        blocks = []
+        slots = []
+        for pos_i in range(block_padding):
+            if pos_i % self.block_size == 0:
+                block_number = self.block_pool.popleft()
+                blocks.append(block_number)
+            block_offset = pos_i % self.block_size
+            slot = block_number * self.block_size + block_offset
+            slots.append(slot)
+        self.req_ids2blocks[req_id] = deque(blocks)
+
+        # Add new request to the cached states.
+        if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
+            generator = torch.Generator(device=self.device)
+            generator.manual_seed(sampling_params.seed)
+        else:
+            generator = None
+
+        req_state = CachedRequestState(req_id=req_id,
+                                       prompt_token_ids=prompt_token_ids,
+                                       sampling_params=sampling_params,
+                                       generator=generator,
+                                       output_token_ids=[],
+                                       left_padding=left_padding)
+        self.requests[req_id] = req_state
+        self.input_batch.add_request(req_state)
+        self.prefill_batch.add_request(req_state)
 
         # Refresh sampling metadata after all request are added to the batch
         self.input_batch.refresh_metadata()
         self.prefill_batch.refresh_metadata()
 
-        # TODO: Review this in the future
-        # prefills are always of batch size 1 for this milestone
-        # Also, we added an input batch just for that.
- actual_batch_size = len(input_token_list) - assert actual_batch_size == 1 - self.model.indices = torch.ones(actual_batch_size, - dtype=torch.bool, - device='cpu') - # construct tensor from list - slot_mapping = torch.tensor(slot_mapping, dtype=torch.int64) - block_table = None + self.model.indices = torch.ones(1, dtype=torch.bool, device='cpu') + slot_mapping = torch.tensor([slots], dtype=torch.int64) + prompt_token_ids_tensor = torch.tensor(prompt_token_ids, + dtype=torch.long, + device=torch.device("cpu")) # get position ids and attention mask # applies left padding to align with tkv of current decode batch # and right padding to align with the next block boundary - input_tokens, position_ids, mask =\ - self.pad_input_ids(input_token_list, min_pad_length=block_padding) + input_tokens, position_ids, mask = self.pad_input_ids( + [prompt_token_ids_tensor], min_pad_length=block_padding) mask = mask.unsqueeze(1) # not needed for prefill current_tkv_mask = None + # left padding info is stored in CachedRequestState of self.requests left_padded_prompt_mask = None + # block table is stored in self.req_ids2blocks (only passed for decode) + block_table = None model_inputs = ModelForwardInputs( input_tokens=input_tokens, From 7c5f3e41013eadbe83f70c163af012b333aeb69e Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Fri, 25 Jul 2025 12:58:18 +0000 Subject: [PATCH 2/4] refactor _prepare_decode (mostly reordering) Signed-off-by: Yannick Schnider --- vllm_spyre/v1/worker/spyre_model_runner.py | 23 +++++++++++----------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index 80212bf2f..2a75f7df5 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -912,9 +912,6 @@ def _prepare_decode( block_table = [] slot_mapping = [] left_padded_prompt_mask = [] - self.model.indices = torch.ones(len(cached_request_data.req_ids), - dtype=torch.bool, - device="cpu") assert len(self.input_batch.req_id_to_index) == len( cached_request_data.req_ids) @@ -951,28 +948,30 @@ def _prepare_decode( left_padded_prompt_mask.append(req_state.left_padding) + # update tkv + self.tkv = self.tkv + 1 + + # construct tensors from lists input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) position_ids = torch.tensor(input_positions, dtype=torch.long, device=self.device) + current_tkv_mask = torch.tensor([self.tkv] * len(input_tokens), + dtype=torch.int64) left_padded_prompt_mask = torch.tensor(left_padded_prompt_mask, dtype=torch.long, device=self.device) - # construct tensors from lists - slot_mapping = torch.tensor(slot_mapping, dtype=torch.int64) block_table = torch.tensor(block_table, dtype=torch.int64) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int64) + self.model.indices = torch.ones(len(cached_request_data.req_ids), + dtype=torch.bool, + device="cpu") - # not needed for decode + # not needed for mask during decode mask = None - # update tkv - self.tkv = self.tkv + 1 - - current_tkv_mask = torch.tensor([self.tkv] * len(input_tokens), - dtype=torch.int64) - # add pads for min decode batch size of 2 (Spyre compiler constraint) if len(cached_request_data.req_ids) == 1: padd_seq_indices = torch.zeros(1, dtype=torch.bool, device="cpu") From f8d5152d467b34ef479155cad79086a1f85b9dd2 Mon Sep 17 00:00:00 2001 From: Yannick Schnider Date: Fri, 25 Jul 2025 13:06:04 +0000 Subject: [PATCH 3/4] adding some comments to prepare_decode Signed-off-by: 
Yannick Schnider
---
 vllm_spyre/v1/worker/spyre_model_runner.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index 2a75f7df5..666c4d9dc 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -928,7 +928,7 @@ def _prepare_decode(
             # or jump decoding?
             req_state: CachedRequestState = self.requests[req_id]
 
-            # adding new blocks if needed
+            # adding new blocks for the current sequence if needed
             if self.tkv // self.block_size + 1 > len(
                     self.req_ids2blocks[req_id]):
                 self.req_ids2blocks[req_id].append(self.block_pool.popleft())
@@ -939,13 +939,16 @@ def _prepare_decode(
             offset = self.tkv % self.block_size
             slot = [start_slot + offset]
             slot_mapping.append(slot)
-            output_token_ids = req_state.output_token_ids
-            generation_token = output_token_ids[-1]
+
+            # input token and position of the token generated in the last step
+            generation_token = req_state.output_token_ids[-1]
             input_tokens.append([generation_token])
             seq_len = cached_request_data.num_computed_tokens[
                 cached_reqs_map[req_id]]
             input_positions.append([seq_len])
 
+            # retrieve left padding information stored during prefill and
+            # updated when calling reduce_left_padding()
             left_padded_prompt_mask.append(req_state.left_padding)
 
             # update tkv
@@ -969,7 +972,7 @@ def _prepare_decode(
                                         dtype=torch.bool,
                                         device="cpu")
 
-        # not needed for mask during decode
+        # mask not needed during decode
         mask = None
 
         # add pads for min decode batch size of 2 (Spyre compiler constraint)

From e9eee49763884bf374811ddf8a875f87ce553ffb Mon Sep 17 00:00:00 2001
From: Yannick Schnider
Date: Fri, 25 Jul 2025 16:05:10 +0000
Subject: [PATCH 4/4] make comment clearer

Signed-off-by: Yannick Schnider
---
 vllm_spyre/v1/worker/spyre_model_runner.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index 666c4d9dc..6ef9025c2 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -826,7 +826,8 @@ def _prepare_prompt(
         # current decode batch (Spyre constraint)
         left_padding = self.tkv - prompt_len
 
-        # Reserve the maximal number of blocks used to serve current sequence
+        # Reserve the number of blocks that this new sequence requires in the
+        # worst case (it may still stop early by producing the EOS token)
         new_tokens = (sampling_params.max_tokens
                       if sampling_params is not None else 0)
         n = self.tkv + new_tokens - 1
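
Note on the prefill geometry in PATCH 1: the right padding to the next block boundary, the left padding up to the tkv of the running decode batch, the worst-case block reservation and the slot mapping are all simple arithmetic over tkv and block_size. The standalone sketch below reproduces only these formulas outside the model runner; the function plan_prefill and its argument list are hypothetical and purely illustrative.

import math
from collections import deque


def plan_prefill(prompt_len, max_tokens, tkv, block_size, block_pool,
                 is_new_batch):
    # right padding to the next block boundary (Spyre prefill constraint)
    n = prompt_len if is_new_batch else tkv
    block_padding = math.ceil(n / block_size) * block_size

    # a new decode batch starts with its tkv at the padded prompt length
    tkv = block_padding if is_new_batch else tkv

    # left padding aligns the prefill with the tkv of the decode batch
    left_padding = tkv - prompt_len

    # worst case: the sequence emits all max_tokens before stopping
    n_reserved_blocks = math.ceil((tkv + max_tokens - 1) / block_size)

    # one KV-cache block per block_size positions, one slot per position
    blocks, slots = [], []
    for pos in range(block_padding):
        if pos % block_size == 0:
            blocks.append(block_pool.popleft())
        slots.append(blocks[-1] * block_size + pos % block_size)

    return tkv, left_padding, n_reserved_blocks, blocks, slots


if __name__ == "__main__":
    # a 70-token prompt opening a new batch with 64-token blocks is padded to
    # 128 tokens, gets 58 tokens of left padding, reserves 3 blocks (20 new
    # tokens requested) and writes its prefill into blocks [0, 1]
    print(plan_prefill(prompt_len=70, max_tokens=20, tkv=0, block_size=64,
                       block_pool=deque(range(100)), is_new_batch=True)[:4])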
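
The decode-side bookkeeping reordered and commented in PATCH 2 and PATCH 3 can be sanity-checked the same way: the tkv grows by one per decode step, a block is appended once the write position crosses a block boundary, and the slot of the newly generated token comes from the last block of the sequence. Again a hedged, self-contained sketch for a single sequence; decode_step is a made-up name, not part of the model runner.

from collections import deque


def decode_step(tkv, block_size, blocks, block_pool):
    # append a fresh block once position `tkv` no longer fits the blocks held
    if tkv // block_size + 1 > len(blocks):
        blocks.append(block_pool.popleft())

    # slot written in this step: offset of position `tkv` in the last block
    slot = blocks[-1] * block_size + tkv % block_size

    # the KV cache now holds one more token for this sequence
    return tkv + 1, slot


if __name__ == "__main__":
    blocks = deque([0, 1])   # blocks filled by the prefill sketched above
    pool = deque(range(2, 100))
    tkv = 128                # right-padded prompt length
    for _ in range(3):
        tkv, slot = decode_step(tkv, 64, blocks, pool)
        print(tkv, slot, list(blocks))
    # -> 129 128 [0, 1, 2]
    #    130 129 [0, 1, 2]
    #    131 130 [0, 1, 2]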