diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index 93e3378bd..502a94b59 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -200,7 +200,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # set env vars for torch_sendnn to consume
         os.environ["VLLM_DT_MAX_CONTEXT_LEN"] = str(
             vllm_config.model_config.max_model_len)
-        # min decode batch size is 2 due to symbolic shape constraint in torch
+        # min value 2 needed for VLLM_DT_MAX_BATCH_SIZE (compiler constraint)
+        # Note that we can still have decodes of batch size 1 as the env var
+        # only concerns the max batch size.
         os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(
             max(vllm_config.scheduler_config.max_num_seqs, 2))
 
diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index eeb51086b..f770e4204 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -291,13 +291,8 @@ def load_model(self, prompt_lens: Iterable[int],
         )
 
     def build_input_batch(self) -> SamplingInputBatch:
-        # Fix for batch size 1: set input batch to fit 2 requests for warmup,
-        # and reset input batch to fit max_num_seqs requests after warmup
-        min_seqs_required = 2 if self.warmup_mode else 1
-
         return SamplingInputBatch(
-            max_num_reqs=max(min_seqs_required,
-                             self.scheduler_config.max_num_seqs),
+            max_num_reqs=self.scheduler_config.max_num_seqs,
             max_model_len=self.model_config.max_model_len,
             device=self.device,
             pin_memory=self.pin_memory,
@@ -810,8 +805,6 @@ def __init__(
 
     def complete_warmup(self) -> None:
         super().complete_warmup()
-        # Fix for batch size 1: need to update the input_batch after the warmup
-        self.input_batch = self.build_input_batch()
         # get the number or pages from the actual Spyre card after the warmup
         # and set it accordingly in the model runner and the kv cache size
         n_blocks_avail = self._get_num_blocks_available()
@@ -1112,23 +1105,6 @@ def _prepare_decode(
         # mask not needed during decode
         mask = None
 
-        # add pads for min decode batch size of 2 (Spyre compiler constraint)
-        if len(cached_request_data.req_ids) == 1:
-            padd_seq_indices = torch.zeros(1, dtype=torch.bool, device="cpu")
-            self.model.indices = torch.cat(
-                (self.model.indices, padd_seq_indices), -1)
-            assert self.model.indices.size(dim=0) == 2
-
-            input_tokens = torch.cat(2 * [input_tokens])
-            position_ids = torch.cat(2 * [position_ids])
-            current_tkv_mask = torch.cat(2 * [current_tkv_mask])
-            left_padded_prompt_mask = torch.cat(2 * [left_padded_prompt_mask])
-            block_table = torch.cat(2 * [block_table])
-            slot_mapping = torch.cat(2 * [slot_mapping])
-
-        # assert min batch size 2 for decodes (Spyre compiler constraint)
-        assert len(input_tokens) >= 2
-
         model_inputs = SamplingForwardInputs(
             input_tokens=input_tokens,
             input_positions=position_ids,
diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py
index 08e51741d..d1013831d 100644
--- a/vllm_spyre/v1/worker/spyre_worker.py
+++ b/vllm_spyre/v1/worker/spyre_worker.py
@@ -330,6 +330,10 @@ def load_model(self):
         logger.info("load model took %.3fs", load_model_total_t)
 
     def _warmup_spyre_dynamic_size(self, special_token_ids):
+        # this setting is required to mark a dimension of size 1 as dynamic
+        # for pytorch >= 2.7.1 (needed to support batch size 1 for decodes)
+        from torch.fx.experimental import _config as config
+        config.backed_size_oblivious = True
         warmup_start_t = time.time()
 
@@ -347,15 +351,14 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
         valid_token_ids_tensor = torch.tensor(valid_token_ids,
                                               dtype=torch.long,
                                               device=torch.device("cpu"))
-        batch_size = 2
         prompt_len = 42
         num_decode_tokens = 2
 
         # Sample from the valid token ids
         warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
-            0, len(valid_token_ids_tensor), (batch_size + 1, prompt_len))]
+            0, len(valid_token_ids_tensor), (2, prompt_len))]
 
-        dummy_requests: list[NewRequestData] = [
+        warmup_req, deploy_req = (
             NewRequestData(
                 req_id="warmup-%d" % (i),
                 prompt_token_ids=warmup_tokens_tensor[i].tolist(),
@@ -367,14 +370,11 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
                 block_ids=[0],  # not actually used
                 num_computed_tokens=0,
                 lora_request=None,
-            ) for i in range(batch_size + 1)
-        ]
-        add_dummy_request = dummy_requests.pop(-1)
+            ) for i in range(2))
 
         with _maybe_warmup_context():
-            self._dynamic_warmup(dummy_requests=dummy_requests,
+            self._dynamic_warmup(request=warmup_req,
                                  prompt_len=prompt_len,
-                                 batch_size=batch_size,
                                  valid_token_ids_tensor=valid_token_ids_tensor)
 
         # warmup_mode completes the graph compilation, but we need to do
@@ -382,9 +382,9 @@
         # the necessary operations are included in the graph and will be removed
         # after this execution
         scheduler_output = SchedulerOutput(
-            scheduled_new_reqs=[add_dummy_request],
+            scheduled_new_reqs=[deploy_req],
             scheduled_cached_reqs=CachedRequestData.make_empty(),
-            num_scheduled_tokens={add_dummy_request.req_id: prompt_len},
+            num_scheduled_tokens={deploy_req.req_id: prompt_len},
             total_num_scheduled_tokens=prompt_len,
             scheduled_spec_decode_tokens={},
             scheduled_encoder_inputs={},
@@ -396,7 +396,7 @@
         )
         logger.info("[WARMUP] Deploying to device...")
         self.execute_model(scheduler_output)
-        self._cleanup_model_runner(request=[add_dummy_request])
+        self._cleanup_model_runner(request=[deploy_req])
 
         model_runner.complete_warmup()
 
@@ -550,9 +550,8 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
 
     def _dynamic_warmup(
         self,
-        dummy_requests: list[NewRequestData],
+        request: NewRequestData,
         prompt_len: int,
-        batch_size: int,
         valid_token_ids_tensor: torch.Tensor,
     ) -> None:
 
@@ -560,51 +559,39 @@ def _dynamic_warmup(
             _inside_warmup_mode
         ), "it looks like you are outside the warmup context for warmup"
 
-        for i, req in enumerate(dummy_requests):
-            scheduler_output = SchedulerOutput(
-                scheduled_new_reqs=[req],
-                scheduled_cached_reqs=CachedRequestData.make_empty(),
-                num_scheduled_tokens={req.req_id: prompt_len},
-                total_num_scheduled_tokens=prompt_len,
-                scheduled_spec_decode_tokens={},
-                scheduled_encoder_inputs={},
-                num_common_prefix_blocks=0,
-                finished_req_ids=set(),
-                free_encoder_input_ids=[],
-                structured_output_request_ids={},
-                grammar_bitmask=None,
-            )
-            logger.info("[WARMUP] Prefill %d/%d...", i + 1, batch_size)
+        scheduler_output = SchedulerOutput(
+            scheduled_new_reqs=[request],
+            scheduled_cached_reqs=CachedRequestData.make_empty(),
+            num_scheduled_tokens={request.req_id: prompt_len},
+            total_num_scheduled_tokens=prompt_len,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=0,
+            finished_req_ids=set(),
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+        )
+        logger.info("[WARMUP] Prefill...")
 
-            self.execute_model(scheduler_output)
+        self.execute_model(scheduler_output)
 
-        # one decode iteration across all sequences
-        req_ids = []
-        new_token_ids = []
-        new_block_ids = []
-        num_computed_tokens = []
-        for req in dummy_requests:
-            req_ids.append(req.req_id)
-            new_token_ids.append([
-                valid_token_ids_tensor[torch.randint(
-                    0, len(valid_token_ids_tensor), (1, )).item()]
-            ])  # placeholder token
-            new_block_ids.append([req.block_ids])
-            num_computed_tokens.append(prompt_len)
         cached_request_data = CachedRequestData(
-            req_ids=req_ids,
+            req_ids=[request.req_id],
             resumed_from_preemption=False,
-            new_token_ids=new_token_ids,
-            new_block_ids=new_block_ids,
-            num_computed_tokens=num_computed_tokens,
+            new_token_ids=[[
+                valid_token_ids_tensor[torch.randint(
+                    0, len(valid_token_ids_tensor), (1, )).item()]
+            ]],
+            new_block_ids=[request.block_ids],
+            num_computed_tokens=[prompt_len],
        )
 
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=[],
             scheduled_cached_reqs=cached_request_data,
-            num_scheduled_tokens={f"warmup-{i}": 1
-                                  for i in range(batch_size)},
-            total_num_scheduled_tokens=batch_size,
+            num_scheduled_tokens={request.req_id: 1},
+            total_num_scheduled_tokens=1,
             scheduled_spec_decode_tokens={},
             scheduled_encoder_inputs={},
             num_common_prefix_blocks=0,
@@ -615,7 +602,7 @@ def _dynamic_warmup(
         )
         logger.info("[WARMUP] Decode...")
         self.execute_model(scheduler_output)
-        self._cleanup_model_runner(request=dummy_requests)
+        self._cleanup_model_runner(request=[request])
 
     def _warmup_model_forward_pass(
         self,
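Side note on the `backed_size_oblivious` flag set in `_warmup_spyre_dynamic_size` above: the sketch below illustrates the underlying PyTorch behavior using a plain `torch.compile` CPU path rather than the Spyre/torch_sendnn stack; the `decode_step` function and tensor shapes are made up for illustration, not part of this change.

# A minimal sketch (assumption: plain CPU torch.compile, not the Spyre /
# torch_sendnn backend) of the 0/1-specialization behavior that
# backed_size_oblivious works around: by default Dynamo specializes a
# dimension of size 1 instead of keeping it symbolic, which is why the
# old code padded decode batches of size 1 up to 2.
import torch
import torch._dynamo as dynamo
from torch.fx.experimental import _config as fx_config

# Same flag the diff sets for pytorch >= 2.7.1: treat backed symbols as
# size-oblivious so that a size-1 batch dimension can stay dynamic.
fx_config.backed_size_oblivious = True


@torch.compile(dynamic=True)
def decode_step(x: torch.Tensor) -> torch.Tensor:
    # stand-in for a real decode forward pass
    return x * 2.0


x1 = torch.randn(1, 8)      # warmup input with batch size 1
dynamo.mark_dynamic(x1, 0)  # ask for a symbolic batch dimension
decode_step(x1)             # compile with the batch dim kept dynamic

x4 = torch.randn(4, 8)      # a larger batch should reuse the same graph
decode_step(x4)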