diff --git a/examples/offline_inference_spyre_cb_test.py b/examples/offline_inference_spyre_cb_test.py new file mode 100644 index 000000000..30ca37e1f --- /dev/null +++ b/examples/offline_inference_spyre_cb_test.py @@ -0,0 +1,83 @@ +import os +import time + +from vllm import LLM, SamplingParams + +max_tokens1 = 10 +max_tokens2 = 5 +max_tokens3 = 7 +max_tokens = max([max_tokens1, max_tokens2, max_tokens3]) +max_num_seqs = 2 # defines max batch size + +os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = '64' +os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max_tokens) + +# defining here to be able to run/debug directly from VSC (not via terminal) +os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = 'eager' +os.environ['VLLM_SPYRE_USE_CB'] = '1' +os.environ['VLLM_USE_V1'] = '1' + +os.environ['VLLM_SPYRE_MAX_CONTEXT_LENGTH'] = '2048' +os.environ['VLLM_SPYRE_MAX_BATCH_SIZE'] = str(max_num_seqs) + +# Sample prompts. +template = ( + "Below is an instruction that describes a task. Write a response that " + "appropriately completes the request. Be polite in your response to the " + "user.\n\n### Instruction:\n{}\n\n### Response:") + +prompt1 = template.format( + "Provide a list of instructions for preparing chicken soup for a family " + "of four.") + +prompt2 = template.format("Provide instructions for preparing chicken soup.") + +prompt3 = template.format( + "Provide a list of instructions for preparing chicken soup for a family.") + +prompts = [ + prompt1, + prompt2, + prompt3, +] + +# Create a sampling params object. +sampling_params1 = SamplingParams(max_tokens=max_tokens1, + temperature=0.0, + ignore_eos=True) + +sampling_params2 = SamplingParams(max_tokens=max_tokens2, + temperature=0.0, + ignore_eos=True) + +sampling_params3 = SamplingParams(max_tokens=max_tokens3, + temperature=0.0, + ignore_eos=True) + +sampling_params = [ + sampling_params1, + sampling_params2, + sampling_params3, +] + +# Create an LLM. +llm = LLM(model="/models/llama-194m", + tokenizer="/models/llama-194m", + max_model_len=2048, + block_size=2048) + +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. 
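The example above drives continuous batching entirely through environment variables. The following standalone sketch (not part of the PR; the helper name and signature are made up for illustration) collects that wiring in one place, using the same variable names and values as the script:

```python
import os

def configure_spyre_cb_env(max_token_counts: list[int],
                           max_num_seqs: int,
                           max_context_len: int = 2048) -> None:
    """Hypothetical helper mirroring the env setup in the example script."""
    os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = "64"
    os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max(max_token_counts))
    os.environ["VLLM_SPYRE_USE_CB"] = "1"         # opt in to continuous batching
    os.environ["VLLM_USE_V1"] = "1"               # CB is only wired up for the V1 engine
    os.environ["VLLM_SPYRE_MAX_CONTEXT_LENGTH"] = str(max_context_len)
    os.environ["VLLM_SPYRE_MAX_BATCH_SIZE"] = str(max_num_seqs)

configure_spyre_cb_env([10, 5, 7], max_num_seqs=2)
```

Keeping VLLM_SPYRE_MAX_BATCH_SIZE equal to the intended max_num_seqs keeps the scheduler limit and the physical KV cache sizing consistent with each other.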
+print("=============== GENERATE") +t0 = time.time() +outputs = llm.generate(prompts, sampling_params) +print("Time elaspsed for %d tokens is %.2f sec" % + (len(outputs[0].outputs[0].token_ids), time.time() - t0)) +print("===============") +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") +print("===============") +for output in outputs: + print(output.outputs[0]) diff --git a/vllm_spyre/envs.py b/vllm_spyre/envs.py index 758bb93cf..6eaaa90bd 100644 --- a/vllm_spyre/envs.py +++ b/vllm_spyre/envs.py @@ -6,6 +6,9 @@ VLLM_SPYRE_WARMUP_PROMPT_LENS: Optional[List[int]] = None VLLM_SPYRE_WARMUP_NEW_TOKENS: Optional[List[int]] = None VLLM_SPYRE_WARMUP_BATCH_SIZES: Optional[List[int]] = None + VLLM_SPYRE_USE_CB: bool = False + VLLM_SPYRE_MAX_BATCH_SIZE: int = 0 + VLLM_SPYRE_MAX_CONTEXT_LENGTH: int = 0 environment_variables: Dict[str, Callable[[], Any]] = { # Defines the prompt lengths the Spyre accelerator should be prepared @@ -40,6 +43,18 @@ # - "eager": Skip compile entirely (for debug and testing "VLLM_SPYRE_DYNAMO_BACKEND": lambda: os.getenv("VLLM_SPYRE_DYNAMO_BACKEND", "sendnn_decoder"), + + # If set, use the V1 continuous batching implementation + "VLLM_SPYRE_USE_CB": + lambda: bool(int(os.getenv("VLLM_SPYRE_USE_CB", "0"))), + + # Maximal supported batch size + "VLLM_SPYRE_MAX_BATCH_SIZE": + lambda: int(os.getenv("VLLM_SPYRE_MAX_BATCH_SIZE", "0")), + + # Maximal supported context length + "VLLM_SPYRE_MAX_CONTEXT_LENGTH": + lambda: int(os.getenv("VLLM_SPYRE_MAX_CONTEXT_LENGTH", "0")), } diff --git a/vllm_spyre/model_executor/model_loader/spyre.py b/vllm_spyre/model_executor/model_loader/spyre.py index d5342749b..61ff01403 100644 --- a/vllm_spyre/model_executor/model_loader/spyre.py +++ b/vllm_spyre/model_executor/model_loader/spyre.py @@ -38,23 +38,31 @@ class SpyreCausalLM(nn.Module): def __init__( self, - config: PretrainedConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + max_prompt_length: int, + max_decode_length: int, ) -> None: super().__init__() - self.config = config - self.logits_processor = LogitsProcessor(config.vocab_size, - logits_as_input=True) + + self.logits_processor = LogitsProcessor( + model_config.hf_config.vocab_size, logits_as_input=True) self.sampler = get_sampler() - self.past_key_value_states = None - self.dtype = torch.float16 if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == \ - 'sendnn_decoder' else torch.float32 + # boolean tensor of length batch size with indices: # True for unfinished sequences and # False for finished or padded sequences self.indices = None - # Lazy initialized - self.model: nn.Module + # FMS Model + fms_model = ContinuousBatchingFmsModel if envs_spyre.VLLM_SPYRE_USE_CB\ + else StaticBatchingFmsModel + self.model = fms_model( + model_config, + parallel_config, + max_prompt_length, + max_decode_length, + ) def forward( self, @@ -62,10 +70,12 @@ def forward( positions: torch.Tensor, masks: torch.Tensor, is_prompt: bool, + tkv: Optional[int] = None, + active_pages: Optional[list[int]] = None, ) -> torch.Tensor: - if is_prompt: - self.past_key_value_states = None + if is_prompt and not envs_spyre.VLLM_SPYRE_USE_CB: + self.model.past_key_value_states = None extra_kwargs = {} if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn_decoder": @@ -73,32 +83,28 @@ def forward( # cpu impl when padding too much extra_kwargs["attn_algorithm"] = "math" - output = self.model( + # normal prefil or decoding step + logits = self.model( 
input_ids, position_ids=positions, mask=masks, - past_key_value_states=self.past_key_value_states, use_cache=True, only_last_token=True, + tkv=tkv, + active_pages=active_pages, **extra_kwargs, ) - logits, past_key_value_states = output - self.past_key_value_states = past_key_value_states - - # mark dynamic - if self.past_key_value_states is not None: - for layer in self.past_key_value_states: - for tensor in layer: - torch._dynamo.mark_dynamic(tensor, 2) - # removing finished or padded sequences logits = logits[self.indices] return logits - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: logits = self.logits_processor(None, hidden_states, sampling_metadata) return logits @@ -110,9 +116,40 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, model_config: ModelConfig, max_prompt_length: int, - max_decode_length: int, - distributed_strategy: Optional[str], **kwargs): + +class FmsModelBase(nn.Module): + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + max_prompt_length: int, + max_decode_length: int, + ) -> None: + super().__init__() + + self.config: PretrainedConfig = model_config.hf_config + self.dtype = torch.float16 if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == \ + 'sendnn_decoder' else torch.float32 + + # Actual FMS model + self.model: nn.Module + + # Load the weights from the cached or downloaded files. + self.load_weights(model_config=model_config, + max_prompt_length=max_prompt_length, + max_decode_length=max_decode_length, + distributed_strategy="tp" + if parallel_config.world_size > 1 else None) + + def load_weights( + self, + model_config: ModelConfig, + max_prompt_length: int, + max_decode_length: int, + distributed_strategy: Optional[str], + **kwargs, + ) -> None: if self.dtype is not model_config.dtype: logger.info( @@ -206,16 +243,141 @@ def load_weights(self, model_config: ModelConfig, max_prompt_length: int, backend=envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND) -def get_spyre_model(model_config: ModelConfig, parallel_config: ParallelConfig, - max_prompt_length, max_decode_length) -> nn.Module: +class ContinuousBatchingFmsModel(FmsModelBase): + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + max_prompt_length: int, + max_decode_length: int, + ) -> None: + super().__init__(model_config, parallel_config, max_prompt_length, + max_decode_length) + + # physical KV cache on AIU Spyre + max_batch = envs_spyre.VLLM_SPYRE_MAX_BATCH_SIZE + max_model_len = envs_spyre.VLLM_SPYRE_MAX_CONTEXT_LENGTH + + if self.config.model_type == 'llama': + num_layers = self.config.num_hidden_layers + num_kv_heads = self.config.num_key_value_heads + head_dim = self.config.hidden_size // \ + self.config.num_attention_heads + elif self.config.model_type == 'gpt_bigcode': + num_layers = self.config.n_layer + num_kv_heads = 1 if self.config.multi_query else self.config.n_head + head_dim = self.config.n_embd // self.config.n_head + else: + print(f"[SpyreCausalLM] model type {self.config.model_type} " + f"not supported in ContinuousBatchingFmsModel") - # Create a model instance. 
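ContinuousBatchingFmsModel sizes a physical KV cache on the accelerator from the HF config, with one (K, V) pair per layer (the allocation follows just below). A standalone sketch of that sizing logic, assuming toy config values; the helper function and the plain-dict config are illustrative only:

```python
import torch

def kv_cache_shape(model_type: str, cfg: dict, max_batch: int, max_model_len: int):
    """Mirror of the llama / gpt_bigcode branches above (illustrative helper)."""
    if model_type == "llama":
        num_layers = cfg["num_hidden_layers"]
        num_kv_heads = cfg["num_key_value_heads"]
        head_dim = cfg["hidden_size"] // cfg["num_attention_heads"]
    elif model_type == "gpt_bigcode":
        num_layers = cfg["n_layer"]
        num_kv_heads = 1 if cfg["multi_query"] else cfg["n_head"]
        head_dim = cfg["n_embd"] // cfg["n_head"]
    else:
        raise ValueError(f"model type {model_type!r} not supported")
    return num_layers, (max_batch, num_kv_heads, max_model_len, head_dim)

# toy numbers, not taken from the PR
num_layers, shape = kv_cache_shape(
    "llama",
    {"num_hidden_layers": 2, "num_key_value_heads": 8,
     "hidden_size": 1024, "num_attention_heads": 16},
    max_batch=2, max_model_len=2048)
fms_kv_cache = [(torch.empty(shape), torch.empty(shape)) for _ in range(num_layers)]
assert fms_kv_cache[0][0].shape == (2, 8, 2048, 64)
```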
- model = SpyreCausalLM(model_config.hf_config) + # (layers)x(k,v)x[max_batch, num_kv_heads, max_model_len, head_dim] + self.fms_kv_cache: list[tuple[torch.Tensor, torch.Tensor]] = [ + (torch.empty((max_batch, num_kv_heads, max_model_len, head_dim)), + torch.empty((max_batch, num_kv_heads, max_model_len, head_dim))) + for i in range(num_layers) + ] - # Load the weights from the cached or downloaded files. - model.load_weights( - model_config, - max_prompt_length=max_prompt_length, - max_decode_length=max_decode_length, - distributed_strategy="tp" if parallel_config.world_size > 1 else None) - return model + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + mask: torch.Tensor, + use_cache: bool, + only_last_token: bool, + tkv: int, + active_pages: list[int], + **extra_kwargs, + ) -> torch.Tensor: + + # read-out (dynamic) kv_cache for decoding steps only, + # for prefills kv_cache = None + if tkv == 0: # prefil + kv_cache = None + tkv = input_ids.shape[1] + else: # decode + kv_cache = [] + active_pages_mask = torch.zeros(self.fms_kv_cache[0][0].shape[0], + dtype=torch.bool) + active_pages_mask[active_pages] = True + for layer in range(len(self.fms_kv_cache)): + kv_cache.append( + (self.fms_kv_cache[layer][0][active_pages_mask, :, :tkv - + 1, :], + self.fms_kv_cache[layer][1][active_pages_mask, :, :tkv - + 1, :])) + + output = self.model( + input_ids, + position_ids=position_ids, + mask=mask, + past_key_value_states=kv_cache, + use_cache=use_cache, + only_last_token=only_last_token, + **extra_kwargs, + ) + logits, key_value_states = output + + # updating (physical) KV cache: self.fms_kv_cache + for idx, page in enumerate(sorted(active_pages)): + for layer in range(len(self.fms_kv_cache)): + # inserting partial KV cache at correct location + # (page, tkv) in the KV cache of the whole batch + self.fms_kv_cache[layer][0][ + page, :, :tkv, :] = key_value_states[layer][0][ + idx, :, :, :] # [1, 8, L, 128] + self.fms_kv_cache[layer][1][ + page, :, :tkv, :] = key_value_states[layer][1][ + idx, :, :, :] # [1, 8, L, 128] + + return logits + + +class StaticBatchingFmsModel(FmsModelBase): + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + max_prompt_length: int, + max_decode_length: int, + ) -> None: + super().__init__(model_config, parallel_config, max_prompt_length, + max_decode_length) + + # dynamic KV cache + self.past_key_value_states = None + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + mask: torch.Tensor, + use_cache: bool, + only_last_token: bool, + tkv: int, + active_pages: list[int], + **extra_kwargs, + ) -> torch.Tensor: + + output = self.model( + input_ids, + position_ids=position_ids, + mask=mask, + past_key_value_states=self.past_key_value_states, + use_cache=use_cache, + only_last_token=only_last_token, + **extra_kwargs, + ) + + logits, past_key_value_states = output + self.past_key_value_states = past_key_value_states + + # mark dynamic + if self.past_key_value_states is not None: + for layer in self.past_key_value_states: + for tensor in layer: + torch._dynamo.mark_dynamic(tensor, 2) + + return logits diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py index 1ec4a898c..47efdfa3e 100644 --- a/vllm_spyre/platform.py +++ b/vllm_spyre/platform.py @@ -57,8 +57,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if envs.VLLM_USE_V1: # As of 0.7.3 the scheduler for V1 isn't actually pluggable like # this yet - scheduler_config.scheduler_cls = \ - 
"vllm_spyre.v1.core.scheduler.SpyreScheduler" + if envs_spyre.VLLM_SPYRE_USE_CB: + scheduler_config.scheduler_cls = \ + "vllm_spyre.v1.core.scheduler.ContinuousBatchingSpyreScheduler" + else: + scheduler_config.scheduler_cls = \ + "vllm_spyre.v1.core.scheduler.SpyreScheduler" else: scheduler_config.scheduler_cls = \ "vllm_spyre.core.scheduler.SpyreScheduler" @@ -74,8 +78,22 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: shape['prompt_length'] + shape['new_tokens']) if envs.VLLM_USE_V1: - # The v0 scheduler will run out of blocks if this is overridden - scheduler_config.max_num_seqs = max_batch_size + if envs_spyre.VLLM_SPYRE_USE_CB: + # For continuous batching we use max_num_seqs to control + # the max batch size respecting AIU Spyre KV cache size + scheduler_config.max_num_seqs =\ + envs_spyre.VLLM_SPYRE_MAX_BATCH_SIZE + # ToDo: this function check_and_update_config is called twice: + # 1st time scheduler_config.max_num_seqs is what user sets + # 2nd time scheduler_config.max_num_seqs is 128 + else: + # The v0 scheduler will run out of blocks if this is overridden + scheduler_config.max_num_seqs = max_batch_size + + # continuous batching related checks + if envs_spyre.VLLM_SPYRE_USE_CB and not envs.VLLM_USE_V1: + raise NotImplementedError( + "Continuous batching is only implemented for vLLM V1") cache_config = vllm_config.cache_config diff --git a/vllm_spyre/v1/core/scheduler.py b/vllm_spyre/v1/core/scheduler.py index cc85e3ce5..d9e33a42c 100644 --- a/vllm_spyre/v1/core/scheduler.py +++ b/vllm_spyre/v1/core/scheduler.py @@ -9,6 +9,8 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus +import vllm_spyre.envs as envs_spyre + try: from vllm.v1.core.sched.scheduler import Scheduler except ImportError: @@ -94,7 +96,6 @@ def update_from_output( def schedule(self) -> SchedulerOutput: """This override adds constraints and then delegates most of the work to the base scheduler""" - # First purge the full waiting queue into our holdback queue, preserving # priority while self.waiting: @@ -199,3 +200,91 @@ def _reject_from_queue(self, self.rejected_requests.remove(request.request_id) return reject_outputs + + +class ContinuousBatchingSpyreScheduler(SpyreScheduler): + """ Support of continuous batching """ + + def __init__(self, *args, **kwargs) -> None: + # Initialize SpyreScheduler + super().__init__(*args, **kwargs) + # running queue of last decoding step + self.last_running: list[Request] = [] + self.total_running: list[Request] = [] + self.running: list[Request] = [] + + def add_request(self, request: Request) -> None: + """This override rejects requests that exceed max context length""" + if not request.num_prompt_tokens + request.sampling_params.max_tokens\ + <= envs_spyre.VLLM_SPYRE_MAX_CONTEXT_LENGTH: + logger.warning( + "Could not add request id %s, prompt length is " + "%d tokens, maximum number of output tokens is %d tokens," + "but max model context length is %d.", + request.request_id, + request.num_prompt_tokens, + request.sampling_params.max_tokens, + envs_spyre.VLLM_SPYRE_MAX_CONTEXT_LENGTH, + ) + # TODO: There are open PRs that should enable raising an error for + # a single request like this, which will gracefully return an error + # for the request, instead of shutting down the engine. 
+ # See https://github.com/vllm-project/vllm/pull/11737 + # raise ValueError("Request does not fit any spyre warmup shape") + + # For now, we'll insert a dummy request and manually reject it when + # we construct the outputs later + self.rejected_requests.add(request.request_id) + request.prompt_token_ids = [0] + request.num_prompt_tokens = 1 + request.sampling_params = SamplingParams(max_tokens=1) + + # delegate to super of SpyreScheduler: base V1 Scheduler + super(SpyreScheduler, self).add_request(request=request) + + def schedule(self) -> "SchedulerOutput": + """This override adds constraints and then delegates most of the work + to the base scheduler""" + # First purge the full waiting queue into our holdback queue, preserving + # priority + while self.waiting: + self.holdback_queue.append(self.waiting.popleft()) + + # Check if new requests can be scheduled. + self.total_running = self.last_running + self.running + while self.holdback_queue: + if self.can_schedule(): + # Add request to the waiting queue + self.waiting.append(self.holdback_queue.popleft()) + else: + # Otherwise, we simply stop here so that the scheduler + # can work with the batch we have + break + + if len(self.waiting) > 0: + # If prefill scheduled, save running queue for the next decode step. + # If previous step was also prefill, running queue contains the + # previous prefill sequence. + self.last_running = self.total_running + self.running = [] + logger.debug( + "Scheduling a prompt step of %d requests, holding back %d " + "requests", len(self.waiting), len(self.holdback_queue)) + else: + # If decode scheduled and previous step was prefil, update running + # queue + self.running = self.total_running + self.last_running = [] + logger.debug("Scheduling a decode step of %d requests", + len(self.running)) + + # delegate to super of SpyreScheduler: base V1 Scheduler + outputs = super(SpyreScheduler, self).schedule() + return outputs + + def can_schedule(self) -> bool: + max_prompt_batch_size = 1 + # TODO: add additional checks, e.g. max_tokens + return len(self.total_running)+len(self.waiting) <\ + self.max_num_running_reqs and\ + len(self.waiting) < max_prompt_batch_size diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py index 46c0cc83c..48580402a 100644 --- a/vllm_spyre/v1/worker/spyre_model_runner.py +++ b/vllm_spyre/v1/worker/spyre_model_runner.py @@ -17,7 +17,7 @@ _add_sampling_metadata_broadcastable_dict, _init_sampling_metadata_from_tensor_dict) -from vllm_spyre.model_executor.model_loader.spyre import get_spyre_model +from vllm_spyre.model_executor.model_loader.spyre import SpyreCausalLM from vllm_spyre.platform import SpyrePlatform from vllm_spyre.v1.worker.spyre_input_batch import (CachedRequestState, InputBatch) @@ -36,6 +36,8 @@ from vllm.v1.outputs import ModelRunnerOutput +import vllm_spyre.envs as envs_spyre + logger = init_logger(__name__) TModelInputForSpyre = TypeVar('TModelInputForSpyre', @@ -99,12 +101,131 @@ def __init__( self.device_config = DeviceConfig() self.device = self.device_config.device self.pin_memory = is_pin_memory_available() + + # Lazy initialization: after load_model. 
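The admission policy in ContinuousBatchingSpyreScheduler.can_schedule() above boils down to two constraints: at most one prompt (prefill) per step, and in-flight plus newly admitted sequences must stay within the max batch size. A toy, self-contained rendition of that policy, assuming the base scheduler simply moves admitted requests into the running list; all names here are illustrative, not vLLM APIs:

```python
from collections import deque

class ToyCBScheduler:
    def __init__(self, max_num_running_reqs: int):
        self.max_num_running_reqs = max_num_running_reqs
        self.holdback: deque = deque()   # requests not yet admitted
        self.running: list = []          # sequences currently decoding
        self.waiting: list = []          # admitted this step (prefill)

    def can_schedule(self) -> bool:
        max_prompt_batch_size = 1        # prefills are batch-of-one for now
        return (len(self.running) + len(self.waiting) < self.max_num_running_reqs
                and len(self.waiting) < max_prompt_batch_size)

    def step(self) -> str:
        while self.holdback and self.can_schedule():
            self.waiting.append(self.holdback.popleft())
        if self.waiting:                 # prompt step: prefill the new sequence
            kind = f"prefill {self.waiting}"
            self.running += self.waiting
            self.waiting = []
        else:                            # decode step for the whole batch
            kind = f"decode {self.running}"
        return kind

sched = ToyCBScheduler(max_num_running_reqs=2)
sched.holdback.extend(["r1", "r2", "r3"])
print(sched.step())  # prefill ['r1']
print(sched.step())  # prefill ['r2']
print(sched.step())  # decode ['r1', 'r2']   (r3 waits until a sequence finishes)
```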
+ self.model: nn.Module + + def get_model(self) -> nn.Module: + return self.model + + def load_model(self, prompt_lens: Iterable[int], + num_decode_tokens: Iterable[int]) -> None: + max_pad_length = max(prompt_lens) + max_decode_length = max(num_decode_tokens) + self.model = SpyreCausalLM(self.model_config, + parallel_config=self.parallel_config, + max_prompt_length=max_pad_length, + max_decode_length=max_decode_length) + + @property + def vocab_size(self) -> int: + return self.model.model.model.config.src_vocab_size + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> ModelInputForSpyre: + return ModelInputForSpyre.from_broadcasted_tensor_dict(tensor_dict) + + def _prepare_pad_input_ids( + self, + input_ids_list: List[torch.Tensor], + min_pad_length: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """left side padding implemented as + in fms.utils.generation.pad_input_id""" + max_len = max([min_pad_length] + + [seq.size(0) for seq in input_ids_list]) + padded_input_ids_list = [] + mask_list = [] + position_ids_list = [] + for input_ids_i in input_ids_list: + seq_len = input_ids_i.size(0) + if max_len > seq_len: + logger.info( + "Padding request of length %d tokens to %d tokens.", + seq_len, max_len) + pads = torch.ones(max_len - seq_len, + dtype=torch.long, + device=input_ids_i.device) * self.pad_token_id + non_pads = torch.ones(seq_len, + dtype=torch.long, + device=input_ids_i.device) + + pos_ids_pads = pads + pos_ids_seq = torch.arange(0, + seq_len, + dtype=torch.long, + device=input_ids_i.device) + + # Setting this to 0, however if 0 is the eos, we will end up + # truncating the output if using truncate_after_eos once this + # workflow works for nested tensor, this can probably be removed + padded_input_ids_list.append(torch.cat((pads, input_ids_i))) + mask_list.append(torch.cat((torch.zeros_like(pads), non_pads))) + position_ids_list.append(torch.cat((pos_ids_pads, pos_ids_seq))) + + return padded_input_ids_list, mask_list, position_ids_list + + def pad_input_ids( + self, + input_ids_list: List[torch.Tensor], + min_pad_length: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + padded_input_ids_list, mask_list, position_ids_list = self.\ + _prepare_pad_input_ids(input_ids_list, min_pad_length) + + input_ids = torch.stack(padded_input_ids_list) + mask = torch.stack(mask_list).bool() + # this is a causal mask for generation + mask = (mask.unsqueeze(-1) == mask.unsqueeze(-2)).tril() + mask = torch.where(mask.logical_not(), -torch.inf, 0.0) + mask = mask.to(self.model.model.dtype) + position_ids = torch.stack(position_ids_list) + + return input_ids, position_ids, mask + + def get_kv_cache_spec(self) -> KVCacheSpec: + """ + This method should generate the KVCache spec by parsing the kv cache + format from each Attention module in the static forward context. + + In vLLM, this static forward context is populated by the base Attention + class in the modeling code. Every attention layer populates an entry + for itself in vllm_config.compilation_config.static_forward_context, + which is a dictionary of layer_name -> layer for every attention layer. + This allows the model runner to correctly create the kv cache spec for + each layer. + + The spyre modeling code currently comes from `fms`, and does not + integrate with vLLM's modeling classes, so we don't have access to any + model-agnostic metadata about the attention layers. This just returns a + dummy value for now. 
+ """ + # We do at least use the real size from the cache config. + block_size = self.vllm_config.cache_config.block_size + + attn_spec = FullAttentionSpec(block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float16, + use_mla=False) + return {"foo": attn_spec} + + +class StaticBatchingSpyreModelRunner(SpyreModelRunner): + + def __init__( + self, + vllm_config: VllmConfig, + is_driver_worker: bool, + ): + super().__init__(vllm_config=vllm_config, + is_driver_worker=is_driver_worker) + # position_ids of all the sequences in current batch self._position_ids: torch.Tensor = None # attention masks of all the sequences in current batch self._mask: torch.Tensor = None - # Lazy initialization: after load_model. - self.model: nn.Module # Batch state self.input_batch = InputBatch( @@ -121,22 +242,6 @@ def __init__( self.spyre_warmup_shapes = SpyrePlatform.get_warmup_shapes( self.scheduler_config) - def get_model(self) -> nn.Module: - return self.model - - def load_model(self, prompt_lens: Iterable[int], - num_decode_tokens: Iterable[int]) -> None: - max_pad_length = max(prompt_lens) - max_decode_length = max(num_decode_tokens) - self.model = get_spyre_model(self.model_config, - parallel_config=self.parallel_config, - max_prompt_length=max_pad_length, - max_decode_length=max_decode_length) - - @property - def vocab_size(self) -> int: - return self.model.model.config.src_vocab_size - def _prepare_prompt( self, new_requests: list[NewRequestData], @@ -260,10 +365,6 @@ def _update_mask(self) -> None: self._mask = torch.stack(masks_new, dim=0) - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> ModelInputForSpyre: - return ModelInputForSpyre.from_broadcasted_tensor_dict(tensor_dict) - def prepare_model_input( self, scheduler_output: SchedulerOutput) -> ModelInputForSpyre: @@ -367,92 +468,6 @@ def execute_model( ) return model_output - def _prepare_pad_input_ids( - self, - input_ids_list: List[torch.Tensor], - min_pad_length: int = 0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """left side padding implemented as - in fms.utils.generation.pad_input_id""" - max_len = max([min_pad_length] + - [seq.size(0) for seq in input_ids_list]) - padded_input_ids_list = [] - mask_list = [] - position_ids_list = [] - for input_ids_i in input_ids_list: - seq_len = input_ids_i.size(0) - if max_len > seq_len: - logger.info( - "Padding request of length %d tokens to %d tokens.", - seq_len, max_len) - pads = torch.ones(max_len - seq_len, - dtype=torch.long, - device=input_ids_i.device) * self.pad_token_id - non_pads = torch.ones(seq_len, - dtype=torch.long, - device=input_ids_i.device) - - pos_ids_pads = pads - pos_ids_seq = torch.arange(0, - seq_len, - dtype=torch.long, - device=input_ids_i.device) - - # Setting this to 0, however if 0 is the eos, we will end up - # truncating the output if using truncate_after_eos once this - # workflow works for nested tensor, this can probably be removed - padded_input_ids_list.append(torch.cat((pads, input_ids_i))) - mask_list.append(torch.cat((torch.zeros_like(pads), non_pads))) - position_ids_list.append(torch.cat((pos_ids_pads, pos_ids_seq))) - - return padded_input_ids_list, mask_list, position_ids_list - - def pad_input_ids( - self, - input_ids_list: List[torch.Tensor], - min_pad_length: int = 0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - padded_input_ids_list, mask_list, position_ids_list = self.\ - _prepare_pad_input_ids(input_ids_list, min_pad_length) - - input_ids = 
torch.stack(padded_input_ids_list) - mask = torch.stack(mask_list).bool() - # this is a causal mask for generation - mask = (mask.unsqueeze(-1) == mask.unsqueeze(-2)).tril() - mask = torch.where(mask.logical_not(), -torch.inf, 0.0) - mask = mask.to(self.model.dtype) - position_ids = torch.stack(position_ids_list) - - return input_ids, position_ids, mask - - def get_kv_cache_spec(self) -> KVCacheSpec: - """ - This method should generate the KVCache spec by parsing the kv cache - format from each Attention module in the static forward context. - - In vLLM, this static forward context is populated by the base Attention - class in the modeling code. Every attention layer populates an entry - for itself in vllm_config.compilation_config.static_forward_context, - which is a dictionary of layer_name -> layer for every attention layer. - This allows the model runner to correctly create the kv cache spec for - each layer. - - The spyre modeling code currently comes from `fms`, and does not - integrate with vLLM's modeling classes, so we don't have access to any - model-agnostic metadata about the attention layers. This just returns a - dummy value for now. - """ - # We do at least use the real size from the cache config. - block_size = self.vllm_config.cache_config.block_size - - attn_spec = FullAttentionSpec(block_size=block_size, - num_kv_heads=1, - head_size=1, - dtype=torch.float16, - use_mla=False) - return {"foo": attn_spec} - def _update_states(self, scheduler_output: SchedulerOutput): # Update the states of the running/resumed requests. # For now, we are updating input_batch.'s `token_ids_cpu`, @@ -518,3 +533,231 @@ def _get_padded_batch_size(self, new_requests: list[NewRequestData]): 'prompt_length'] padded_batch_size = applicable_spyre_warmup_shapes[0]['batch_size'] return padded_batch_size, min_pad_length_batch + + +class ContinuousBatchingSpyreModelRunner(SpyreModelRunner): + + def __init__( + self, + vllm_config: VllmConfig, + is_driver_worker: bool, + ): + super().__init__(vllm_config=vllm_config, + is_driver_worker=is_driver_worker) + + max_batch_size = envs_spyre.VLLM_SPYRE_MAX_BATCH_SIZE + # this is just to pass formatting bc type is Optional[list[int]] + if envs_spyre.VLLM_SPYRE_WARMUP_PROMPT_LENS: + max_prompt_length = envs_spyre.VLLM_SPYRE_WARMUP_PROMPT_LENS[0] + + # TO DO: move to InputBatch + self.req_ids2page: dict = {} + self.active_pages: List[int] = [] + self.tkv = 0 + self.free_pages = [i for i in range(max_batch_size)] + self.min_pad_length_batch = max_prompt_length + + def _prepare_prompt( + self, + new_requests: List[NewRequestData], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + assert len(new_requests) > 0 + input_token_list: List[torch.Tensor] = [] + + # Internal state is managed here. 
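pad_input_ids() and _prepare_pad_input_ids() above implement left padding plus the causal attention mask handed to the FMS model. A self-contained sketch of the same construction on a toy two-sequence batch (pad token id 0 is a stand-in, shapes are toy-sized):

```python
import torch

pad_token_id = 0
seqs = [torch.tensor([11, 12, 13]), torch.tensor([21, 22])]
max_len = max(s.size(0) for s in seqs)

padded, masks, pos_ids = [], [], []
for s in seqs:
    n_pad = max_len - s.size(0)
    pads = torch.full((n_pad,), pad_token_id, dtype=torch.long)
    padded.append(torch.cat((pads, s)))                        # left padding
    masks.append(torch.cat((torch.zeros(n_pad), torch.ones(s.size(0)))))
    pos_ids.append(torch.cat((pads, torch.arange(s.size(0)))))

input_ids = torch.stack(padded)          # tensor([[11, 12, 13], [ 0, 21, 22]])
mask = torch.stack(masks).bool()
# pairwise "both are real tokens" matrix, lower-triangular for causality,
# then 0.0 where attention is allowed and -inf where it is masked out
mask = (mask.unsqueeze(-1) == mask.unsqueeze(-2)).tril()
mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
position_ids = torch.stack(pos_ids)      # tensor([[0, 1, 2], [0, 0, 1]])
print(mask.shape)                        # torch.Size([2, 3, 3])
```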
+ self.active_pages = [] + for request_data in new_requests: + free_page_idx = self.free_pages.pop(0) + self.active_pages.append(free_page_idx) + self.req_ids2page[request_data.req_id] = free_page_idx + + # retrieve initial (unpadded) tokens + prompt_tokens = request_data.prompt_token_ids + + input_token_list.append( + torch.tensor(prompt_tokens, + dtype=torch.long, + device=torch.device("cpu"))) + + # prefils are always of batch size 1 for this milestone + actual_batch_size = len(input_token_list) + assert actual_batch_size == 1 + self.model.indices = torch.ones(actual_batch_size, + dtype=torch.bool, + device='cpu') + + if self.tkv == 0: + self.tkv = self.min_pad_length_batch + + # get position ids and attention mask + input_tokens, position_ids, mask =\ + self.pad_input_ids(input_token_list, min_pad_length=self.tkv) + + seq_lens = [t.shape[0] for t in input_token_list] + + return input_tokens, position_ids, mask, seq_lens + + def _prepare_decode( + self, + cached_requests: List[CachedRequestData], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert len(cached_requests) > 0 + input_tokens = [] + self.active_pages = [] + self.model.indices = torch.ones(len(cached_requests), + dtype=torch.bool, + device='cpu') + + for cached_request in cached_requests: + # TODO: Will this always just be one token ID if there's no spec + # or jump decoding? + self.active_pages.append(self.req_ids2page[cached_request.req_id]) + generation_token = cached_request.new_token_ids[-1] + input_tokens.append([generation_token]) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + mask, position_ids = self._prepare_pos_mask_decode( + cached_requests, self.tkv) + self.tkv = self.tkv + 1 + + return input_tokens, position_ids, mask + + def _prepare_pos_mask_decode( + self, + cached_requests: List[CachedRequestData], + tkv: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + mask_list = [] + position_ids_list = [] + + for cached_request in cached_requests: + seq_len = cached_request.num_computed_tokens + position_ids_list.append([seq_len]) + + pads = torch.ones(tkv - seq_len, + dtype=torch.long, + device=self.device) * self.pad_token_id + non_pads = torch.ones(seq_len + 1, + dtype=torch.long, + device=self.device) + mask_list.append(torch.cat((torch.zeros_like(pads), non_pads))) + + mask = torch.stack(mask_list).bool() + mask = torch.where(mask.logical_not(), -torch.inf, 0.0) + mask = mask.to(self.model.model.dtype) + mask = torch.unsqueeze(mask, dim=1) + position_ids = torch.tensor(position_ids_list, + dtype=torch.long, + device=self.device) + return mask, position_ids + + def prepare_model_input( + self, scheduler_output: SchedulerOutput) -> ModelInputForSpyre: + + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + # Also assuming that new sequences are prefills + is_prompt = len(scheduler_output.scheduled_new_reqs) > 0 + + for req_id in scheduler_output.finished_req_ids: + if req_id in self.req_ids2page: + self.free_pages.append(self.req_ids2page[req_id]) + del self.req_ids2page[req_id] + + # Prepare input tensors. 
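The continuous-batching runner above maps each request to a fixed physical page (row) of the Spyre KV cache via req_ids2page, and prepare_model_input hands pages back to free_pages when a request shows up in finished_req_ids. A minimal standalone sketch of that lifecycle (the helper functions are made up; the real code does this inline):

```python
max_batch_size = 2
free_pages = list(range(max_batch_size))   # physical rows in the device KV cache
req_ids2page: dict = {}

def allocate(req_id: str) -> int:
    """Take the lowest free page for a newly prefilled request."""
    page = free_pages.pop(0)
    req_ids2page[req_id] = page
    return page

def release(req_id: str) -> None:
    """Return the page of a finished request to the free list."""
    if req_id in req_ids2page:
        free_pages.append(req_ids2page.pop(req_id))

allocate("req-0")        # page 0
allocate("req-1")        # page 1
release("req-0")         # req-0 finished, page 0 is reusable
allocate("req-2")        # the next prefill reuses page 0
assert req_ids2page == {"req-1": 1, "req-2": 0}
```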
+ if is_prompt: + (input_tokens, input_positions, input_masks, + _) = self._prepare_prompt(scheduler_output.scheduled_new_reqs) + num_reqs = len(scheduler_output.scheduled_new_reqs) + else: + (input_tokens, input_positions, input_masks) = \ + self._prepare_decode(scheduler_output.scheduled_cached_reqs) + num_reqs = len(scheduler_output.scheduled_cached_reqs) + + # TODO: Build the rest of the SamplingMetadata correctly + dummy_tensors = lambda v: torch.full( + (num_reqs, ), v, device=self.device) + dummy_metadata = SamplingMetadata( + temperature=dummy_tensors(0.0), + all_greedy=False, + all_random=False, + top_p=None, + top_k=None, + min_p=None, + generators={}, + max_num_logprobs=None, + no_penalties=True, + prompt_token_ids=None, + frequency_penalties=dummy_tensors(0.1), + presence_penalties=dummy_tensors(0.1), + repetition_penalties=dummy_tensors(0.1), + output_token_ids=[[] for _ in range(num_reqs)], + min_tokens={}, + logit_bias=[None for _ in range(num_reqs)], + allowed_token_ids_mask=None, + bad_words_token_ids=None, + ) + + return ModelInputForSpyre(input_tokens=input_tokens, + input_positions=input_positions, + input_masks=input_masks, + sampling_metadata=dummy_metadata, + is_prompt=is_prompt) + + @SpyrePlatform.inference_mode() + def execute_model( + self, + scheduler_output: SchedulerOutput, + **kwargs, + ) -> ModelRunnerOutput: + + t0 = time.time() + model_input = self.prepare_model_input(scheduler_output) + + # Execute the model + hidden_states = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + masks=model_input.input_masks, + is_prompt=model_input.is_prompt, + tkv=0 if model_input.is_prompt else self.tkv, + active_pages=self.active_pages, + ) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return [] + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, None) + + # Sample the next token. 
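execute_model below passes tkv=0 to the model to mark a prefill and the current tkv for decode steps; the counter is pinned to the warmup prompt length on the first prefill and grows by one per decode (see _prepare_prompt/_prepare_decode above). A toy model of that bookkeeping, simplified and with illustrative names:

```python
class TkvClock:
    """Toy stand-in for the runner's tkv ("tokens in KV cache") counter."""

    def __init__(self, warmup_prompt_len: int = 64):
        self.warmup_prompt_len = warmup_prompt_len
        self.tkv = 0

    def prefill(self) -> int:
        if self.tkv == 0:
            # first prefill: prompts are left-padded to the warmup length
            self.tkv = self.warmup_prompt_len
        return 0                  # the model receives tkv=0 to detect a prefill

    def decode(self) -> int:
        self.tkv += 1             # one generated token lands in the cache
        return self.tkv           # the model reads cached rows [: tkv - 1]

clock = TkvClock()
assert clock.prefill() == 0 and clock.tkv == 64
assert clock.decode() == 65       # first decode after a 64-token-padded prefill
assert clock.decode() == 66
```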
+ output: SamplerOutput = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + t1 = time.time() - t0 + logger.debug("t_token: %.2fms", (t1 * 1000)) + + is_prompt = len(scheduler_output.scheduled_new_reqs) > 0 + scheduled_req = scheduler_output.scheduled_new_reqs if is_prompt\ + else scheduler_output.scheduled_cached_reqs + # since same order as in _prepare_prompt/decode req_ids2idx not needed + req_ids = [req.req_id for req in scheduled_req] + req_id_to_index = {req_id: i for i, req_id in enumerate(req_ids)} + + model_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_id_to_index, + sampled_token_ids=output.sampled_token_ids.tolist(), + spec_token_ids=None, + logprobs=output.logprobs_tensors.tolists() + if output.logprobs_tensors else None, + prompt_logprobs_dict={req_id: None + for req_id in req_ids + } # TODO(wallas?): prompt logprobs too + ) + return model_output diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index cdac7e049..f43cc05b7 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -24,7 +24,8 @@ from vllm_spyre.platform import SpyrePlatform from vllm_spyre.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) -from vllm_spyre.v1.worker.spyre_model_runner import SpyreModelRunner +from vllm_spyre.v1.worker.spyre_model_runner import ( + ContinuousBatchingSpyreModelRunner, StaticBatchingSpyreModelRunner) logger = init_logger(__name__) @@ -46,6 +47,10 @@ def get_kv_cache_spec(self) -> KVCacheSpec: def compile_or_warm_up_model(self) -> None: """Prepare model for execution through compilation/warmup.""" + # TO DO: implement warmup for continuous batching + if envs_spyre.VLLM_SPYRE_USE_CB: + return + wup_prompt_lens, wup_new_tokens = zip( *[(s["prompt_length"], s["new_tokens"]) for s in self.spyre_warmup_shapes]) @@ -139,8 +144,12 @@ def __init__( if self.model_config.task == "embed": raise NotImplementedError else: - self.model_runner = SpyreModelRunner(self.vllm_config, - self.is_driver_worker) + if envs_spyre.VLLM_SPYRE_USE_CB: + self.model_runner = ContinuousBatchingSpyreModelRunner( + self.vllm_config, self.is_driver_worker) + else: + self.model_runner = StaticBatchingSpyreModelRunner( + self.vllm_config, self.is_driver_worker) self._env_initialized = False self.spyre_warmup_shapes = SpyrePlatform.get_warmup_shapes( self.vllm_config.scheduler_config) diff --git a/vllm_spyre/worker/spyre_model_runner.py b/vllm_spyre/worker/spyre_model_runner.py index e91e77c2c..d6f9da180 100644 --- a/vllm_spyre/worker/spyre_model_runner.py +++ b/vllm_spyre/worker/spyre_model_runner.py @@ -18,7 +18,7 @@ _add_sampling_metadata_broadcastable_dict, _init_sampling_metadata_from_tensor_dict) -from vllm_spyre.model_executor.model_loader.spyre import get_spyre_model +from vllm_spyre.model_executor.model_loader.spyre import SpyreCausalLM if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -111,14 +111,14 @@ def load_model(self, prompt_lens: Iterable[int], num_decode_tokens: Iterable[int]) -> None: max_pad_length = max(prompt_lens) max_decode_length = max(num_decode_tokens) - self.model = get_spyre_model(self.model_config, - parallel_config=self.parallel_config, - max_prompt_length=max_pad_length, - max_decode_length=max_decode_length) + self.model = SpyreCausalLM(self.model_config, + parallel_config=self.parallel_config, + max_prompt_length=max_pad_length, + max_decode_length=max_decode_length) @property def 
vocab_size(self) -> int: - return self.model.model.config.src_vocab_size + return self.model.model.model.config.src_vocab_size def _prepare_prompt( self, @@ -403,7 +403,7 @@ def pad_input_ids( # this is a causal mask for generation mask = (mask.unsqueeze(-1) == mask.unsqueeze(-2)).tril() mask = torch.where(mask.logical_not(), -torch.inf, 0.0) - mask = mask.to(self.model.dtype) + mask = mask.to(self.model.model.dtype) position_ids = torch.stack(position_ids_list) return input_ids, position_ids, mask @@ -421,10 +421,11 @@ def _raw_model_forward( ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: - return self.model.model(input_ids, - mask=mask, - position_ids=position_ids, - past_key_value_states=past_key_value_states, - use_cache=use_cache, - only_last_token=only_last_token, - attn_algorithm=attn_algorithm) + return self.model.model.model( + input_ids, + mask=mask, + position_ids=position_ids, + past_key_value_states=past_key_value_states, + use_cache=use_cache, + only_last_token=only_last_token, + attn_algorithm=attn_algorithm)
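The extra ".model" hops in the v0 runner above come from the new wrapper layering: the runner holds a SpyreCausalLM, which now holds a Static/ContinuousBatchingFmsModel (an FmsModelBase), which in turn holds the actual FMS module. A dummy-class sketch of why the attribute paths changed; none of these classes are the real implementations:

```python
from types import SimpleNamespace

class FakeFmsModule:            # stands in for the compiled FMS model
    def __init__(self):
        self.config = SimpleNamespace(src_vocab_size=32000)

class FakeFmsModelBase:         # StaticBatching/ContinuousBatchingFmsModel
    def __init__(self):
        self.model = FakeFmsModule()
        self.dtype = "float32"

class FakeSpyreCausalLM:        # what the runner stores as self.model
    def __init__(self):
        self.model = FakeFmsModelBase()

class FakeRunner:
    def __init__(self):
        self.model = FakeSpyreCausalLM()

runner = FakeRunner()
# old path: runner.model.model.config.src_vocab_size
# new path: one extra hop through the FmsModelBase wrapper
assert runner.model.model.model.config.src_vocab_size == 32000
assert runner.model.model.dtype == "float32"   # used for the attention-mask dtype
```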