diff --git a/examples/offline_inference_spyre_cb.py b/examples/offline_inference_spyre_cb.py deleted file mode 100644 index 223223385..000000000 --- a/examples/offline_inference_spyre_cb.py +++ /dev/null @@ -1,53 +0,0 @@ -import os -import time - -from vllm import LLM, SamplingParams - -max_tokens = 15 -early_stop = 5 - -os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = '64' -os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max_tokens) -os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = '4' - -# defining here to be able to run/debug directly from VSC (not via terminal) -os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = 'eager' -os.environ['MASTER_ADDR'] = 'localhost' -os.environ['MASTER_PORT'] = '12355' -os.environ['VLLM_SPYRE_USE_CB'] = '1' -os.environ['VLLM_SPYRE_MAX_BATCH_SIZE'] = '4' -os.environ['VLLM_SPYRE_MAX_CONTEXT_LENGTH'] = '2048' - -# Sample prompts. -template = ( - "Below is an instruction that describes a task. Write a response that " - "appropriately completes the request. Be polite in your response to the " - "user.\n\n### Instruction:\n{}\n\n### Response:") -prompt1 = template.format( - "Provide a list of instructions for preparing chicken soup for a family " - "of four.") -prompts = [prompt1, prompt1, prompt1, prompt1] - -# Create a sampling params object: first sequence will terminate early and will -# replaced with continuous batching -sampling_params = [] -for i in range(4): - sampling_params.append( - SamplingParams(max_tokens=early_stop if i == 0 else max_tokens, - temperature=0.0, - ignore_eos=True)) -# Create an LLM -llm = LLM(model="models/llama-194m", - tokenizer="models/llama-194m", - max_model_len=2048, - block_size=2048) - -# Generate texts from the prompts. The output is a list of RequestOutput objects -# that contain the prompt, generated text, and other information. -print("=============== GENERATE") -t0 = time.time() -outputs = llm.generate(prompts, sampling_params) -print("Time elaspsed for %d tokens is %.2f sec" % - (len(outputs[0].outputs[0].token_ids), time.time() - t0)) -for output in outputs: - print(output.outputs[0]) diff --git a/examples/offline_inference_spyre_cb_test.py b/examples/offline_inference_spyre_cb_test.py new file mode 100644 index 000000000..30ca37e1f --- /dev/null +++ b/examples/offline_inference_spyre_cb_test.py @@ -0,0 +1,83 @@ +import os +import time + +from vllm import LLM, SamplingParams + +max_tokens1 = 10 +max_tokens2 = 5 +max_tokens3 = 7 +max_tokens = max([max_tokens1, max_tokens2, max_tokens3]) +max_num_seqs = 2 # defines max batch size + +os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = '64' +os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max_tokens) + +# defining here to be able to run/debug directly from VSC (not via terminal) +os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = 'eager' +os.environ['VLLM_SPYRE_USE_CB'] = '1' +os.environ['VLLM_USE_V1'] = '1' + +os.environ['VLLM_SPYRE_MAX_CONTEXT_LENGTH'] = '2048' +os.environ['VLLM_SPYRE_MAX_BATCH_SIZE'] = str(max_num_seqs) + +# Sample prompts. +template = ( + "Below is an instruction that describes a task. Write a response that " + "appropriately completes the request. 
Be polite in your response to the " + "user.\n\n### Instruction:\n{}\n\n### Response:") + +prompt1 = template.format( + "Provide a list of instructions for preparing chicken soup for a family " + "of four.") + +prompt2 = template.format("Provide instructions for preparing chicken soup.") + +prompt3 = template.format( + "Provide a list of instructions for preparing chicken soup for a family.") + +prompts = [ + prompt1, + prompt2, + prompt3, +] + +# Create a sampling params object. +sampling_params1 = SamplingParams(max_tokens=max_tokens1, + temperature=0.0, + ignore_eos=True) + +sampling_params2 = SamplingParams(max_tokens=max_tokens2, + temperature=0.0, + ignore_eos=True) + +sampling_params3 = SamplingParams(max_tokens=max_tokens3, + temperature=0.0, + ignore_eos=True) + +sampling_params = [ + sampling_params1, + sampling_params2, + sampling_params3, +] + +# Create an LLM. +llm = LLM(model="/models/llama-194m", + tokenizer="/models/llama-194m", + max_model_len=2048, + block_size=2048) + +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +print("=============== GENERATE") +t0 = time.time() +outputs = llm.generate(prompts, sampling_params) +print("Time elaspsed for %d tokens is %.2f sec" % + (len(outputs[0].outputs[0].token_ids), time.time() - t0)) +print("===============") +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") +print("===============") +for output in outputs: + print(output.outputs[0]) diff --git a/vllm_spyre/model_executor/model_loader/spyre.py b/vllm_spyre/model_executor/model_loader/spyre.py index be0276345..1854386f3 100644 --- a/vllm_spyre/model_executor/model_loader/spyre.py +++ b/vllm_spyre/model_executor/model_loader/spyre.py @@ -33,10 +33,6 @@ logger = init_logger(__name__) -# for testing use offline_inference_spyre_cb.py -TESTING_CB = False -PRINTS_CB = False - class SpyreCausalLM(nn.Module): @@ -68,21 +64,18 @@ def __init__( max_decode_length, ) - # horizontal offset in physical KV cache memory block - self.tkv: int = 0 - def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, masks: torch.Tensor, is_prompt: bool, + tkv: Optional[int] = None, + active_pages: Optional[list[int]] = None, ) -> torch.Tensor: - if is_prompt: - self.tkv = 0 - if not envs_spyre.VLLM_SPYRE_USE_CB: - self.model.past_key_value_states = None + if is_prompt and not envs_spyre.VLLM_SPYRE_USE_CB: + self.model.past_key_value_states = None extra_kwargs = {} if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn_decoder": @@ -90,85 +83,17 @@ def forward( # cpu impl when padding too much extra_kwargs["attn_algorithm"] = "math" - # testing only: prefil after 5 decodes - if TESTING_CB and self.tkv == (5 + 64): - # define sample prompt - input_ids_list = [ - 128000, 39314, 374, 459, 7754, 430, 16964, 264, 3465, 13, 9842, - 264, 2077, 430, 36001, 45695, 279, 1715, 13, 2893, 48887, 304, - 701, 2077, 311, 279, 1217, 382, 14711, 30151, 512, 61524, 264, - 1160, 315, 11470, 369, 20646, 16553, 19724, 369, 264, 3070, - 315, 3116, 382, 14711, 6075, 25 - ] - prompt_len = len(input_ids_list) - tkv_insert = self.tkv - 1 - padding_len = tkv_insert - prompt_len - - # construct token and position ids for the sample prompt - self.model.sample_token_id = torch.tensor( - [0] * padding_len + input_ids_list).unsqueeze(0) - self.model.sample_position = torch.tensor( - [0] * padding_len + [i - for i in 
range(prompt_len)]).unsqueeze(0) - - # Construct attention mask for the sample prompt - n = tkv_insert - m = prompt_len - - top = torch.cat( - (torch.tril(torch.ones(n - m, n - m)), torch.zeros(n - m, m)), - dim=1) - bottom = torch.cat( - (torch.zeros(m, n - m), torch.tril(torch.ones(m, m))), dim=1) - matrix = torch.cat((top, bottom), dim=0) - - matrix = matrix.masked_fill(matrix == 1, - 0).masked_fill(matrix == 0, - float('-inf')) - self.model.sample_mask = matrix.unsqueeze(0) - - # prefil of batch size 1 - logits, self.tkv = self.model( - self.model.sample_token_id, - position_ids=self.model.sample_position, - mask=self.model.sample_mask, - use_cache=True, - only_last_token=True, - tkv=0, - active_pages=[0], - **extra_kwargs, - ) - - if PRINTS_CB: - print('inserted sequence token id: ', - torch.argmax(logits[0, :])) - - # update sample_token_id, sample_position and sample_mask - self.model.update_sample_inputs(logits=logits[0, :]) - - if TESTING_CB and self.tkv >= (5 + 64): - # set input_ids, positions, masks for inserted sequence - input_ids[0, :] = self.model.sample_token_id - positions[0, :] = self.model.sample_position - masks[0, :, :] = self.model.sample_mask - # normal prefil or decoding step - logits, self.tkv = self.model( + logits = self.model( input_ids, position_ids=positions, mask=masks, use_cache=True, only_last_token=True, - tkv=self.tkv, - active_pages=[i for i in range(input_ids.shape[0])], + tkv=tkv, + active_pages=active_pages, **extra_kwargs, ) - if TESTING_CB and self.tkv >= (6 + 64): - # update sample_token_id, sample_position and sample_mask - self.model.update_sample_inputs(logits=logits[0, :]) - if PRINTS_CB: - print('inserted sequence token id: ', - torch.argmax(logits[0, :])) # removing finished or padded sequences logits = logits[self.indices] @@ -368,12 +293,6 @@ def __init__( for i in range(num_layers) ] - # variables used for testing insertion of sample input - if TESTING_CB: - self.sample_token_id: torch.Tensor = torch.empty((1, 1)) - self.sample_position: torch.Tensor = torch.empty((1, 1)) - self.sample_mask: torch.Tensor = torch.empty((1, 1, 1)) - def forward( self, input_ids: torch.Tensor, @@ -386,10 +305,6 @@ def forward( **extra_kwargs, ) -> torch.Tensor: - if PRINTS_CB: - print("tkv", tkv) - print("active_pages", active_pages) - # read-out (dynamic) kv_cache for decoding steps only, # for prefills kv_cache = None if tkv == 0: # prefil @@ -430,19 +345,7 @@ def forward( page, :, :tkv, :] = key_value_states[layer][1][ idx, :, :, :] # [1, 8, L, 128] - return logits, tkv + 1 - - def update_sample_inputs( - self, - logits: torch.Tensor, - ) -> None: - - self.sample_token_id = torch.argmax(logits).clone().detach().reshape( - (1, 1)) - self.sample_position = (self.sample_position[0, -1] + - 1).clone().detach().reshape((1, 1)) - self.sample_mask = torch.nn.functional.pad( - self.sample_mask[0, -1, :].unsqueeze(0), (0, 1)).unsqueeze(0) + return logits class FmsModelPseudoWrapper(FmsModelBaseWrapper): @@ -494,4 +397,4 @@ def forward( for tensor in layer: torch._dynamo.mark_dynamic(tensor, 2) - return logits, tkv + 1 + return logits diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py index 1ec4a898c..47efdfa3e 100644 --- a/vllm_spyre/platform.py +++ b/vllm_spyre/platform.py @@ -57,8 +57,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if envs.VLLM_USE_V1: # As of 0.7.3 the scheduler for V1 isn't actually pluggable like # this yet - scheduler_config.scheduler_cls = \ - "vllm_spyre.v1.core.scheduler.SpyreScheduler" + if 
envs_spyre.VLLM_SPYRE_USE_CB: + scheduler_config.scheduler_cls = \ + "vllm_spyre.v1.core.scheduler.ContinuousBatchingSpyreScheduler" + else: + scheduler_config.scheduler_cls = \ + "vllm_spyre.v1.core.scheduler.SpyreScheduler" else: scheduler_config.scheduler_cls = \ "vllm_spyre.core.scheduler.SpyreScheduler" @@ -74,8 +78,22 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: shape['prompt_length'] + shape['new_tokens']) if envs.VLLM_USE_V1: - # The v0 scheduler will run out of blocks if this is overridden - scheduler_config.max_num_seqs = max_batch_size + if envs_spyre.VLLM_SPYRE_USE_CB: + # For continuous batching, we use max_num_seqs to control + # the max batch size, respecting the AIU Spyre KV cache size + scheduler_config.max_num_seqs =\ + envs_spyre.VLLM_SPYRE_MAX_BATCH_SIZE + # TODO: check_and_update_config is called twice: + # the 1st time scheduler_config.max_num_seqs is what the user set, + # the 2nd time scheduler_config.max_num_seqs is 128 + else: + # The v0 scheduler will run out of blocks if this is overridden + scheduler_config.max_num_seqs = max_batch_size + + # continuous batching related checks + if envs_spyre.VLLM_SPYRE_USE_CB and not envs.VLLM_USE_V1: + raise NotImplementedError( + "Continuous batching is only implemented for vLLM V1") cache_config = vllm_config.cache_config diff --git a/vllm_spyre/v1/core/scheduler.py b/vllm_spyre/v1/core/scheduler.py index f31502479..8c6173033 100644 --- a/vllm_spyre/v1/core/scheduler.py +++ b/vllm_spyre/v1/core/scheduler.py @@ -9,6 +9,8 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus +import vllm_spyre.envs as envs_spyre + try: from vllm.v1.core.sched.scheduler import Scheduler except ImportError: @@ -93,7 +95,6 @@ def update_from_output( def schedule(self) -> SchedulerOutput: """This override adds constraints and then delegates most of the work to the base scheduler""" - # First purge the full waiting queue into our holdback queue, preserving # priority while self.waiting: @@ -195,3 +196,90 @@ def _reject_from_queue(self, self.rejected_requests.remove(request.request_id) return reject_outputs + + +class ContinuousBatchingSpyreScheduler(SpyreScheduler): + """Scheduler with support for continuous batching.""" + + def __init__(self, *args, **kwargs) -> None: + # Initialize SpyreScheduler + super().__init__(*args, **kwargs) + # running queue of last decoding step + self.last_running: list[Request] = [] + self.total_running: list[Request] = [] + self.running: list[Request] = [] + + def add_request(self, request: Request) -> None: + """This override rejects requests that exceed the max context length""" + if not request.num_prompt_tokens + request.sampling_params.max_tokens\ + <= envs_spyre.VLLM_SPYRE_MAX_CONTEXT_LENGTH: + logger.warning( + "Could not add request id %s, prompt length is " + "%d tokens, maximum number of output tokens is %d tokens, " + "but the max model context length is %d.", + request.request_id, + request.num_prompt_tokens, + request.sampling_params.max_tokens, + envs_spyre.VLLM_SPYRE_MAX_CONTEXT_LENGTH, + ) + # TODO: There are open PRs that should enable raising an error for + # a single request like this, which will gracefully return an error + # for the request, instead of shutting down the engine.
+ # See https://github.com/vllm-project/vllm/pull/11737 + # raise ValueError("Request does not fit any spyre warmup shape") + + # For now, we'll insert a dummy request and manually reject it when + # we construct the outputs later + self.rejected_requests.add(request.request_id) + request.prompt_token_ids = [0] + request.num_prompt_tokens = 1 + request.sampling_params = SamplingParams(max_tokens=1) + + # delegate to super + super(SpyreScheduler, self).add_request(request=request) + + def schedule(self) -> "SchedulerOutput": + """This override adds constraints and then delegates most of the work + to the base scheduler""" + # First purge the full waiting queue into our holdback queue, preserving + # priority + while self.waiting: + self.holdback_queue.append(self.waiting.popleft()) + + # Check if new requests can be scheduled. + self.total_running = self.last_running + self.running + while self.holdback_queue: + if self.can_schedule(): + # Add request to the waiting queue + self.waiting.append(self.holdback_queue.popleft()) + else: + # Otherwise, we simply stop here so that the scheduler + # can work with the batch we have + break + + if len(self.waiting) > 0: + # If prefill scheduled, save running queue for the next decode step. + # If previous step was also prefill, running queue contains the + # previous prefill sequence. + self.last_running = self.total_running + self.running = [] + logger.debug( + "Scheduling a prompt step of %d requests, holding back %d " + "requests", len(self.waiting), len(self.holdback_queue)) + else: + # If decode scheduled and previous step was prefil, update running + # queue + self.running = self.total_running + self.last_running = [] + logger.debug("Scheduling a decode step of %d requests", + len(self.running)) + + outputs = super(SpyreScheduler, self).schedule() + return outputs + + def can_schedule(self) -> bool: + max_prompt_batch_size = 1 + # TODO: add additional checks, e.g. max_tokens + return len(self.total_running)+len(self.waiting) <\ + self.max_num_running_reqs and\ + len(self.waiting) < max_prompt_batch_size diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index bfb535a49..6165ac814 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -36,6 +36,8 @@ from vllm.v1.outputs import ModelRunnerOutput +import vllm_spyre.envs as envs_spyre + logger = init_logger(__name__) TModelInputForSpyre = TypeVar('TModelInputForSpyre', @@ -99,12 +101,131 @@ def __init__( self.device_config = DeviceConfig() self.device = self.device_config.device self.pin_memory = is_pin_memory_available() + + # Lazy initialization: after load_model. 
+ self.model: nn.Module + + def get_model(self) -> nn.Module: + return self.model + + def load_model(self, prompt_lens: Iterable[int], + num_decode_tokens: Iterable[int]) -> None: + max_pad_length = max(prompt_lens) + max_decode_length = max(num_decode_tokens) + self.model = get_spyre_model(self.model_config, + parallel_config=self.parallel_config, + max_prompt_length=max_pad_length, + max_decode_length=max_decode_length) + + @property + def vocab_size(self) -> int: + return self.model.model.model.config.src_vocab_size + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> ModelInputForSpyre: + return ModelInputForSpyre.from_broadcasted_tensor_dict(tensor_dict) + + def _prepare_pad_input_ids( + self, + input_ids_list: List[torch.Tensor], + min_pad_length: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """left side padding implemented as + in fms.utils.generation.pad_input_id""" + max_len = max([min_pad_length] + + [seq.size(0) for seq in input_ids_list]) + padded_input_ids_list = [] + mask_list = [] + position_ids_list = [] + for input_ids_i in input_ids_list: + seq_len = input_ids_i.size(0) + if max_len > seq_len: + logger.info( + "Padding request of length %d tokens to %d tokens.", + seq_len, max_len) + pads = torch.ones(max_len - seq_len, + dtype=torch.long, + device=input_ids_i.device) * self.pad_token_id + non_pads = torch.ones(seq_len, + dtype=torch.long, + device=input_ids_i.device) + + pos_ids_pads = pads + pos_ids_seq = torch.arange(0, + seq_len, + dtype=torch.long, + device=input_ids_i.device) + + # Setting this to 0, however if 0 is the eos, we will end up + # truncating the output if using truncate_after_eos once this + # workflow works for nested tensor, this can probably be removed + padded_input_ids_list.append(torch.cat((pads, input_ids_i))) + mask_list.append(torch.cat((torch.zeros_like(pads), non_pads))) + position_ids_list.append(torch.cat((pos_ids_pads, pos_ids_seq))) + + return padded_input_ids_list, mask_list, position_ids_list + + def pad_input_ids( + self, + input_ids_list: List[torch.Tensor], + min_pad_length: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + padded_input_ids_list, mask_list, position_ids_list = self.\ + _prepare_pad_input_ids(input_ids_list, min_pad_length) + + input_ids = torch.stack(padded_input_ids_list) + mask = torch.stack(mask_list).bool() + # this is a causal mask for generation + mask = (mask.unsqueeze(-1) == mask.unsqueeze(-2)).tril() + mask = torch.where(mask.logical_not(), -torch.inf, 0.0) + mask = mask.to(self.model.model.dtype) + position_ids = torch.stack(position_ids_list) + + return input_ids, position_ids, mask + + def get_kv_cache_spec(self) -> KVCacheSpec: + """ + This method should generate the KVCache spec by parsing the kv cache + format from each Attention module in the static forward context. + + In vLLM, this static forward context is populated by the base Attention + class in the modeling code. Every attention layer populates an entry + for itself in vllm_config.compilation_config.static_forward_context, + which is a dictionary of layer_name -> layer for every attention layer. + This allows the model runner to correctly create the kv cache spec for + each layer. + + The spyre modeling code currently comes from `fms`, and does not + integrate with vLLM's modeling classes, so we don't have access to any + model-agnostic metadata about the attention layers. This just returns a + dummy value for now. 
+ """ + # We do at least use the real size from the cache config. + block_size = self.vllm_config.cache_config.block_size + + attn_spec = FullAttentionSpec(block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float16, + use_mla=False) + return {"foo": attn_spec} + + +class StaticBatchingSpyreModelRunner(SpyreModelRunner): + + def __init__( + self, + vllm_config: VllmConfig, + is_driver_worker: bool, + ): + super().__init__(vllm_config=vllm_config, + is_driver_worker=is_driver_worker) + # position_ids of all the sequences in current batch self._position_ids: torch.Tensor = None # attention masks of all the sequences in current batch self._mask: torch.Tensor = None - # Lazy initialization: after load_model. - self.model: nn.Module # Batch state self.input_batch = InputBatch( @@ -121,22 +242,6 @@ def __init__( self.spyre_warmup_shapes = SpyrePlatform.get_warmup_shapes( self.scheduler_config) - def get_model(self) -> nn.Module: - return self.model - - def load_model(self, prompt_lens: Iterable[int], - num_decode_tokens: Iterable[int]) -> None: - max_pad_length = max(prompt_lens) - max_decode_length = max(num_decode_tokens) - self.model = get_spyre_model(self.model_config, - parallel_config=self.parallel_config, - max_prompt_length=max_pad_length, - max_decode_length=max_decode_length) - - @property - def vocab_size(self) -> int: - return self.model.model.model.config.src_vocab_size - def _prepare_prompt( self, new_requests: list[NewRequestData], @@ -260,10 +365,6 @@ def _update_mask(self) -> None: self._mask = torch.stack(masks_new, dim=0) - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> ModelInputForSpyre: - return ModelInputForSpyre.from_broadcasted_tensor_dict(tensor_dict) - def prepare_model_input( self, scheduler_output: SchedulerOutput) -> ModelInputForSpyre: @@ -353,92 +454,6 @@ def execute_model( ) return model_output - def _prepare_pad_input_ids( - self, - input_ids_list: List[torch.Tensor], - min_pad_length: int = 0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """left side padding implemented as - in fms.utils.generation.pad_input_id""" - max_len = max([min_pad_length] + - [seq.size(0) for seq in input_ids_list]) - padded_input_ids_list = [] - mask_list = [] - position_ids_list = [] - for input_ids_i in input_ids_list: - seq_len = input_ids_i.size(0) - if max_len > seq_len: - logger.info( - "Padding request of length %d tokens to %d tokens.", - seq_len, max_len) - pads = torch.ones(max_len - seq_len, - dtype=torch.long, - device=input_ids_i.device) * self.pad_token_id - non_pads = torch.ones(seq_len, - dtype=torch.long, - device=input_ids_i.device) - - pos_ids_pads = pads - pos_ids_seq = torch.arange(0, - seq_len, - dtype=torch.long, - device=input_ids_i.device) - - # Setting this to 0, however if 0 is the eos, we will end up - # truncating the output if using truncate_after_eos once this - # workflow works for nested tensor, this can probably be removed - padded_input_ids_list.append(torch.cat((pads, input_ids_i))) - mask_list.append(torch.cat((torch.zeros_like(pads), non_pads))) - position_ids_list.append(torch.cat((pos_ids_pads, pos_ids_seq))) - - return padded_input_ids_list, mask_list, position_ids_list - - def pad_input_ids( - self, - input_ids_list: List[torch.Tensor], - min_pad_length: int = 0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - padded_input_ids_list, mask_list, position_ids_list = self.\ - _prepare_pad_input_ids(input_ids_list, min_pad_length) - - input_ids = 
torch.stack(padded_input_ids_list) - mask = torch.stack(mask_list).bool() - # this is a causal mask for generation - mask = (mask.unsqueeze(-1) == mask.unsqueeze(-2)).tril() - mask = torch.where(mask.logical_not(), -torch.inf, 0.0) - mask = mask.to(self.model.model.dtype) - position_ids = torch.stack(position_ids_list) - - return input_ids, position_ids, mask - - def get_kv_cache_spec(self) -> KVCacheSpec: - """ - This method should generate the KVCache spec by parsing the kv cache - format from each Attention module in the static forward context. - - In vLLM, this static forward context is populated by the base Attention - class in the modeling code. Every attention layer populates an entry - for itself in vllm_config.compilation_config.static_forward_context, - which is a dictionary of layer_name -> layer for every attention layer. - This allows the model runner to correctly create the kv cache spec for - each layer. - - The spyre modeling code currently comes from `fms`, and does not - integrate with vLLM's modeling classes, so we don't have access to any - model-agnostic metadata about the attention layers. This just returns a - dummy value for now. - """ - # We do at least use the real size from the cache config. - block_size = self.vllm_config.cache_config.block_size - - attn_spec = FullAttentionSpec(block_size=block_size, - num_kv_heads=1, - head_size=1, - dtype=torch.float16, - use_mla=False) - return {"foo": attn_spec} - def _update_states(self, scheduler_output: SchedulerOutput): # Update the states of the running/resumed requests. # For now, we are updating input_batch.'s `token_ids_cpu`, @@ -504,3 +519,231 @@ def _get_padded_batch_size(self, new_requests: list[NewRequestData]): 'prompt_length'] padded_batch_size = applicable_spyre_warmup_shapes[0]['batch_size'] return padded_batch_size, min_pad_length_batch + + +class ContinuousBatchingSpyreModelRunner(SpyreModelRunner): + + def __init__( + self, + vllm_config: VllmConfig, + is_driver_worker: bool, + ): + super().__init__(vllm_config=vllm_config, + is_driver_worker=is_driver_worker) + + max_batch_size = envs_spyre.VLLM_SPYRE_MAX_BATCH_SIZE + # this is just to pass formatting bc type is Optional[list[int]] + if envs_spyre.VLLM_SPYRE_WARMUP_PROMPT_LENS: + max_prompt_length = envs_spyre.VLLM_SPYRE_WARMUP_PROMPT_LENS[0] + + # TO DO: move to InputBatch + self.req_ids2page: dict = {} + self.active_pages: List[int] = [] + self.tkv = 0 + self.free_pages = [i for i in range(max_batch_size)] + self.min_pad_length_batch = max_prompt_length + + def _prepare_prompt( + self, + new_requests: List[NewRequestData], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + assert len(new_requests) > 0 + input_token_list: List[torch.Tensor] = [] + + # Internal state is managed here. 
+ self.active_pages = [] + for request_data in new_requests: + free_page_idx = self.free_pages.pop(0) + self.active_pages.append(free_page_idx) + self.req_ids2page[request_data.req_id] = free_page_idx + + # retrieve initial (unpadded) tokens + prompt_tokens = request_data.prompt_token_ids + + input_token_list.append( + torch.tensor(prompt_tokens, + dtype=torch.long, + device=torch.device("cpu"))) + + # prefills are always of batch size 1 for this milestone + actual_batch_size = len(input_token_list) + assert actual_batch_size == 1 + self.model.indices = torch.ones(actual_batch_size, + dtype=torch.bool, + device='cpu') + + if self.tkv == 0: + self.tkv = self.min_pad_length_batch + + # get position ids and attention mask + input_tokens, position_ids, mask =\ + self.pad_input_ids(input_token_list, min_pad_length=self.tkv) + + seq_lens = [t.shape[0] for t in input_token_list] + + return input_tokens, position_ids, mask, seq_lens + + def _prepare_decode( + self, + cached_requests: List[CachedRequestData], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert len(cached_requests) > 0 + input_tokens = [] + self.active_pages = [] + self.model.indices = torch.ones(len(cached_requests), + dtype=torch.bool, + device='cpu') + + for cached_request in cached_requests: + # TODO: Will this always just be one token ID if there's no spec + # or jump decoding? + self.active_pages.append(self.req_ids2page[cached_request.req_id]) + generation_token = cached_request.new_token_ids[-1] + input_tokens.append([generation_token]) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + mask, position_ids = self._prepare_pos_mask_decode( + cached_requests, self.tkv) + self.tkv = self.tkv + 1 + + return input_tokens, position_ids, mask + + def _prepare_pos_mask_decode( + self, + cached_requests: List[CachedRequestData], + tkv: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + mask_list = [] + position_ids_list = [] + + for cached_request in cached_requests: + seq_len = cached_request.num_computed_tokens + position_ids_list.append([seq_len]) + + pads = torch.ones(tkv - seq_len, + dtype=torch.long, + device=self.device) * self.pad_token_id + non_pads = torch.ones(seq_len + 1, + dtype=torch.long, + device=self.device) + mask_list.append(torch.cat((torch.zeros_like(pads), non_pads))) + + mask = torch.stack(mask_list).bool() + mask = torch.where(mask.logical_not(), -torch.inf, 0.0) + mask = mask.to(self.model.model.dtype) + mask = torch.unsqueeze(mask, dim=1) + position_ids = torch.tensor(position_ids_list, + dtype=torch.long, + device=self.device) + return mask, position_ids + + def prepare_model_input( + self, scheduler_output: SchedulerOutput) -> ModelInputForSpyre: + + # NOTE: We assume that the sequences in the group are either all + # prompts or all decodes. + # Also assuming that new sequences are prefills + is_prompt = len(scheduler_output.scheduled_new_reqs) > 0 + + for req_id in scheduler_output.finished_req_ids: + if req_id in self.req_ids2page: + self.free_pages.append(self.req_ids2page[req_id]) + del self.req_ids2page[req_id] + + # Prepare input tensors.
+ if is_prompt: + (input_tokens, input_positions, input_masks, + _) = self._prepare_prompt(scheduler_output.scheduled_new_reqs) + num_reqs = len(scheduler_output.scheduled_new_reqs) + else: + (input_tokens, input_positions, input_masks) = \ + self._prepare_decode(scheduler_output.scheduled_cached_reqs) + num_reqs = len(scheduler_output.scheduled_cached_reqs) + + # TODO: Build the rest of the SamplingMetadata correctly + dummy_tensors = lambda v: torch.full( + (num_reqs, ), v, device=self.device) + dummy_metadata = SamplingMetadata( + temperature=dummy_tensors(0.0), + all_greedy=False, + all_random=False, + top_p=None, + top_k=None, + min_p=None, + generators={}, + max_num_logprobs=None, + no_penalties=True, + prompt_token_ids=None, + frequency_penalties=dummy_tensors(0.1), + presence_penalties=dummy_tensors(0.1), + repetition_penalties=dummy_tensors(0.1), + output_token_ids=[[] for _ in range(num_reqs)], + min_tokens={}, + logit_bias=[None for _ in range(num_reqs)], + allowed_token_ids_mask=None, + bad_words_token_ids=None, + ) + + return ModelInputForSpyre(input_tokens=input_tokens, + input_positions=input_positions, + input_masks=input_masks, + sampling_metadata=dummy_metadata, + is_prompt=is_prompt) + + @SpyrePlatform.inference_mode() + def execute_model( + self, + scheduler_output: SchedulerOutput, + **kwargs, + ) -> ModelRunnerOutput: + + t0 = time.time() + model_input = self.prepare_model_input(scheduler_output) + + # Execute the model + hidden_states = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + masks=model_input.input_masks, + is_prompt=model_input.is_prompt, + tkv=0 if model_input.is_prompt else self.tkv, + active_pages=self.active_pages, + ) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return [] + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, None) + + # Sample the next token. 
+ output: SamplerOutput = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + t1 = time.time() - t0 + logger.debug("t_token: %.2fms", (t1 * 1000)) + + is_prompt = len(scheduler_output.scheduled_new_reqs) > 0 + scheduled_req = scheduler_output.scheduled_new_reqs if is_prompt\ + else scheduler_output.scheduled_cached_reqs + # since same order as in _prepare_prompt/decode req_ids2idx not needed + req_ids = [req.req_id for req in scheduled_req] + req_id_to_index = {req_id: i for i, req_id in enumerate(req_ids)} + + model_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_id_to_index, + sampled_token_ids=output.sampled_token_ids.tolist(), + spec_token_ids=None, + logprobs=output.logprobs_tensors.tolists() + if output.logprobs_tensors else None, + prompt_logprobs_dict={req_id: None + for req_id in req_ids + } # TODO(wallas?): prompt logprobs too + ) + return model_output diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index cdac7e049..f43cc05b7 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -24,7 +24,8 @@ from vllm_spyre.platform import SpyrePlatform from vllm_spyre.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) -from vllm_spyre.v1.worker.spyre_model_runner import SpyreModelRunner +from vllm_spyre.v1.worker.spyre_model_runner import ( + ContinuousBatchingSpyreModelRunner, StaticBatchingSpyreModelRunner) logger = init_logger(__name__) @@ -46,6 +47,10 @@ def get_kv_cache_spec(self) -> KVCacheSpec: def compile_or_warm_up_model(self) -> None: """Prepare model for execution through compilation/warmup.""" + # TO DO: implement warmup for continuous batching + if envs_spyre.VLLM_SPYRE_USE_CB: + return + wup_prompt_lens, wup_new_tokens = zip( *[(s["prompt_length"], s["new_tokens"]) for s in self.spyre_warmup_shapes]) @@ -139,8 +144,12 @@ def __init__( if self.model_config.task == "embed": raise NotImplementedError else: - self.model_runner = SpyreModelRunner(self.vllm_config, - self.is_driver_worker) + if envs_spyre.VLLM_SPYRE_USE_CB: + self.model_runner = ContinuousBatchingSpyreModelRunner( + self.vllm_config, self.is_driver_worker) + else: + self.model_runner = StaticBatchingSpyreModelRunner( + self.vllm_config, self.is_driver_worker) self._env_initialized = False self.spyre_warmup_shapes = SpyrePlatform.get_warmup_shapes( self.vllm_config.scheduler_config)
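
Usage sketch: a minimal offline script showing how the continuous-batching path introduced in this patch can be exercised, following the environment variables read in platform.py and the new examples/offline_inference_spyre_cb_test.py. The model path is a placeholder and the flags may change as the feature stabilizes; this is an illustration, not part of the patch.

    import os
    from vllm import LLM, SamplingParams

    # Continuous batching is only wired up for the V1 engine
    # (see the NotImplementedError check added in platform.py).
    os.environ["VLLM_USE_V1"] = "1"
    os.environ["VLLM_SPYRE_USE_CB"] = "1"
    # Bound the batch size and context length for the Spyre KV cache.
    os.environ["VLLM_SPYRE_MAX_BATCH_SIZE"] = "2"
    os.environ["VLLM_SPYRE_MAX_CONTEXT_LENGTH"] = "2048"
    os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = "64"
    os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = "10"
    os.environ["VLLM_SPYRE_DYNAMO_BACKEND"] = "eager"

    # Placeholder model path, mirroring the example script in this patch.
    llm = LLM(model="/models/llama-194m",
              tokenizer="/models/llama-194m",
              max_model_len=2048,
              block_size=2048)

    prompts = ["Provide instructions for preparing chicken soup."]
    params = SamplingParams(max_tokens=10, temperature=0.0, ignore_eos=True)
    outputs = llm.generate(prompts, params)
    print(outputs[0].outputs[0].text)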