diff --git a/docs/user_guide/configuration.md b/docs/user_guide/configuration.md index 0e30e2562..fb4a5d805 100644 --- a/docs/user_guide/configuration.md +++ b/docs/user_guide/configuration.md @@ -13,8 +13,8 @@ To run inference on IBM Spyre Accelerators, the backend should be set as: | --- | --- | --- | --- | | Decoder | v0 | sendnn | V0 support for decoder models is deprecated | | Decoder | v1 | sendnn | | -| Embedding | v0 | sendnn | | -| Embedding | v1 | N/A | Embedding models are not yet supported on V1 | +| Embedding | v0 | sendnn | V0 support for embedding models is deprecated| +| Embedding | v1 | sendnn | | ## Batching Modes diff --git a/docs/user_guide/supported_features.md b/docs/user_guide/supported_features.md index 196b237c3..f45b388f9 100644 --- a/docs/user_guide/supported_features.md +++ b/docs/user_guide/supported_features.md @@ -10,7 +10,7 @@ This table summarize the status of features on Spyre. By default, those features | Prompt Adapter | ⛔ | Being deprecated in vLLM [vllm#13981](https://github.com/vllm-project/vllm/issues/13981) | | Speculative Decoding | 🗓️ | | | Guided Decoding | 🗓️ | | -| Pooling | ⚠️ | Works with V0. V1 still being developed in vLLM [vllm#18052](https://github.com/vllm-project/vllm/issues/18052) | +| Pooling | ✅ | | | Enc-dec | ⛔ | No plans for now | | Multi Modality | 🗓️ | | | LogProbs | ✅ | | diff --git a/examples/offline_inference/cb_spyre_inference.py b/examples/offline_inference/cb_spyre_inference.py index 04aa59390..d11c23bd5 100644 --- a/examples/offline_inference/cb_spyre_inference.py +++ b/examples/offline_inference/cb_spyre_inference.py @@ -42,7 +42,6 @@ if "VLLM_SPYRE_DYNAMO_BACKEND" not in os.environ: os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = 'eager' os.environ['VLLM_SPYRE_USE_CB'] = '1' -os.environ['VLLM_USE_V1'] = '1' template = ( "Below is an instruction that describes a task. Write a response that " diff --git a/examples/offline_inference/long_context.py b/examples/offline_inference/long_context.py index e52c56221..7e7deb5e8 100644 --- a/examples/offline_inference/long_context.py +++ b/examples/offline_inference/long_context.py @@ -68,7 +68,6 @@ if "VLLM_SPYRE_DYNAMO_BACKEND" not in os.environ: os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = 'eager' os.environ['VLLM_SPYRE_USE_CB'] = '1' -os.environ['VLLM_USE_V1'] = '1' template = ("Summarize the following code: \n\n{}") diff --git a/tests/conftest.py b/tests/conftest.py index 7d5a89c51..c8432f989 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,5 @@ -# 🌶️🌶️🌶️ Hack to allow testing of both engines -import os - -# If `VLLM_USE_V1=1` is set upon first vLLM import, then there is a side effect -# that will cause the V1 engine to always be selected. This is intentionally -# done for backwards-compatibility of code that was using the AsyncLLMEngine -# constructor directly, instead of using the `.from_engine_args` construction -# methods that will select the appropriate v0 or v1 engine. See: -# https://github.com/vllm-project/vllm/blob/v0.8.4/vllm/engine/llm_engine.py#L2169-L2171 -# Deleting VLLM_USE_V1 here before importing vLLM allows us to continue testing -# both engines. 
-if "VLLM_USE_V1" in os.environ: - del os.environ["VLLM_USE_V1"] -# 🌶️🌶️🌶️ end hack - import hashlib +import os import random import pytest @@ -98,8 +84,7 @@ def remote_openai_server(request): max_num_seqs = params["max_num_seqs"] env_dict = { "VLLM_SPYRE_USE_CB": "1", - "VLLM_SPYRE_DYNAMO_BACKEND": backend, - "VLLM_USE_V1": "1" + "VLLM_SPYRE_DYNAMO_BACKEND": backend } server_args = [ "--max_num_seqs", @@ -121,8 +106,6 @@ def remote_openai_server(request): ','.join(map(str, warmup_batch_size)), "VLLM_SPYRE_DYNAMO_BACKEND": backend, - "VLLM_USE_V1": - "1" } # Default to None if not present diff --git a/tests/e2e/test_spyre_async_llm.py b/tests/e2e/test_spyre_async_llm.py index 7c8c30725..8fbe9ca5a 100644 --- a/tests/e2e/test_spyre_async_llm.py +++ b/tests/e2e/test_spyre_async_llm.py @@ -65,7 +65,6 @@ async def test_abort( ): """Test handling of cancelled requests""" with monkeypatch.context() as m, ExitStack() as after: - m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend) if cb == 1: m.setenv("VLLM_SPYRE_USE_CB", "1") diff --git a/tests/e2e/test_spyre_basic.py b/tests/e2e/test_spyre_basic.py index 2a27b9558..d02a1ed3b 100644 --- a/tests/e2e/test_spyre_basic.py +++ b/tests/e2e/test_spyre_basic.py @@ -205,7 +205,6 @@ def test_full_batch_scheduling(model: str, backend: str, monkeypatch): f"{max_batched_tokens}") monkeypatch.setenv("VLLM_SPYRE_WARMUP_NEW_TOKENS", "20") - monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend) # Setup the engine diff --git a/tests/e2e/test_spyre_embeddings.py b/tests/e2e/test_spyre_embeddings.py index 7e62b3c57..49c0ab2c9 100644 --- a/tests/e2e/test_spyre_embeddings.py +++ b/tests/e2e/test_spyre_embeddings.py @@ -18,12 +18,10 @@ [(64, 4), (64, 8), (128, 4), (128, 8)]) # (prompt_length/batch_size) @pytest.mark.parametrize("backend", get_spyre_backend_list()) -@pytest.mark.parametrize("vllm_version", ["V0", "V1"]) def test_output( model: str, warmup_shape: tuple[int, int], backend: str, - vllm_version: str, monkeypatch, ) -> None: ''' @@ -34,7 +32,6 @@ def test_output( ''' monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend) - monkeypatch.setenv("VLLM_USE_V1", "1" if vllm_version == "V1" else "0") patch_warmup_shapes([warmup_shape], monkeypatch) prompts = get_chicken_soup_prompts(1) @@ -44,8 +41,7 @@ def test_output( max_model_len=256, block_size=256, tensor_parallel_size=1, - backend=backend, - vllm_version=vllm_version) + backend=backend) hf_results = st_embeddings(model=model, prompts=prompts) @@ -65,12 +61,10 @@ def test_output( ]) # (prompt_length/batch_size) @pytest.mark.parametrize("backend", get_spyre_backend_list()) @pytest.mark.parametrize("model", get_spyre_model_list(isEmbeddings=True)) -@pytest.mark.parametrize("vllm_version", ["V0", "V1"]) def test_scheduling_invariance( model, backend, warmup_shape: tuple[int, int], - vllm_version, monkeypatch, ) -> None: ''' @@ -83,7 +77,6 @@ def test_scheduling_invariance( ''' monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend) - monkeypatch.setenv("VLLM_USE_V1", "1" if vllm_version == "V1" else "0") patch_warmup_shapes([warmup_shape], monkeypatch) prompts = get_chicken_soup_prompts(4) diff --git a/tests/e2e/test_spyre_prompt_logprobs.py b/tests/e2e/test_spyre_prompt_logprobs.py index 681031a4e..c237ffff5 100644 --- a/tests/e2e/test_spyre_prompt_logprobs.py +++ b/tests/e2e/test_spyre_prompt_logprobs.py @@ -40,7 +40,6 @@ def test_prompt_logprobs( prompts = get_chicken_soup_prompts(4) - monkeypatch.setenv("VLLM_USE_V1", "1") 
monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend) monkeypatch.setenv("VLLM_SPYRE_ENABLE_PROMPT_LOGPROBS", "1") llm = LLM(model, tensor_parallel_size=tp_size, tokenizer=model) diff --git a/tests/e2e/test_spyre_static_batching_limits.py b/tests/e2e/test_spyre_static_batching_limits.py index d62d5bced..633cf6a65 100644 --- a/tests/e2e/test_spyre_static_batching_limits.py +++ b/tests/e2e/test_spyre_static_batching_limits.py @@ -30,7 +30,6 @@ def test_max_prompt_len_and_new_tokens(model: str, ''' monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend) patch_warmup_shapes(warmup_shapes, monkeypatch) - monkeypatch.setenv("VLLM_USE_V1", "1") max_prompt_length = max([t[0] for t in warmup_shapes]) max_new_tokens = max([t[1] for t in warmup_shapes]) diff --git a/tests/scheduling_utils.py b/tests/scheduling_utils.py index 957b7fe1e..843605246 100644 --- a/tests/scheduling_utils.py +++ b/tests/scheduling_utils.py @@ -56,7 +56,6 @@ def check_scheduler_inference_steps( """ # set env vars - monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend) if use_cb: monkeypatch.setenv("VLLM_SPYRE_USE_CB", "1") diff --git a/tests/spyre_util.py b/tests/spyre_util.py index 18ab37eac..a1e5d7a4b 100644 --- a/tests/spyre_util.py +++ b/tests/spyre_util.py @@ -203,7 +203,6 @@ def generate_spyre_vllm_output( ",".join(str(val) for val in warmup_batch_size)) # -------------- monkeypatch.setenv("VLLM_SPYRE_USE_CB", "1" if use_cb else "0") - monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend) # Allows to run multiprocess V1 engine without dumping meaningless logs at @@ -417,8 +416,7 @@ def check_output_against_hf(model, backend, max_new_tokens, vllm_results, # vLLM / Spyre def spyre_vllm_embeddings(model: str, prompts: list[str], max_model_len: int, block_size: int, tensor_parallel_size: int, - backend: str, - vllm_version: str) -> list[dict[str, Any]]: + backend: str) -> list[dict[str, Any]]: vllm_model = LLM(model=model, tokenizer=model, diff --git a/tests/v1/worker/test_spyre_input_batch.py b/tests/v1/worker/test_spyre_input_batch.py index 3b1e000a7..b33c611a8 100644 --- a/tests/v1/worker/test_spyre_input_batch.py +++ b/tests/v1/worker/test_spyre_input_batch.py @@ -211,7 +211,6 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool: sampling_metadata.bad_words_token_ids -@pytest.mark.v1 @pytest.mark.worker @pytest.mark.parametrize("batch_size", [1, 2, 32, 64]) def test_sampling_metadata_in_input_batch(batch_size: int): diff --git a/vllm_spyre/core/__init__.py b/vllm_spyre/core/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/vllm_spyre/core/scheduler.py b/vllm_spyre/core/scheduler.py deleted file mode 100644 index 2dc740c98..000000000 --- a/vllm_spyre/core/scheduler.py +++ /dev/null @@ -1,1730 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import random -import time -from collections import deque -from collections.abc import Iterable -from typing import Callable, Optional, Union - -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig -from vllm.core.interfaces import AllocStatus, BlockSpaceManager -# SPYRE SPECIFIC CODE BLOCK START -# yapf: disable -from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT, - ARTIFICIAL_PREEMPTION_PROB, - ENABLE_ARTIFICIAL_PREEMPT, - PartialPrefillMetadata, PreemptionMode, - ScheduledSequenceGroup, SchedulerOutputs, - SchedulerPrefillOutputs, - SchedulerRunningOutputs, - SchedulerSwappedInOutputs, SchedulingBudget, - 
scheduled_seq_group_builder, - scheduler_running_outputs_builder, - seq_group_metadata_builder) -from vllm.logger import init_logger -# SPYRE SPECIFIC CODE BLOCK END -from vllm.sequence import (Sequence, SequenceData, SequenceGroup, - SequenceGroupBase, SequenceGroupMetadata, - SequenceGroupMetadataDelta, SequenceStatus) -from vllm.utils import Device, PyObjectCache - -from vllm_spyre.platform import SpyrePlatform - -logger = init_logger(__name__) - - -class SpyreScheduler: - - def __init__( - self, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - lora_config: Optional[LoRAConfig], - pipeline_parallel_size: int = 1, - output_proc_callback: Optional[Callable] = None, - ) -> None: - self.scheduler_config = scheduler_config - self.cache_config = cache_config - # Note for LoRA scheduling: the current policy is extremely - # simple and NOT fair. It can lead to starvation of some - # LoRAs. This should be improved in the future. - self.lora_config = lora_config - - version = "selfattn" - if (self.scheduler_config.runner_type == "pooling" - or self.cache_config.is_attention_free): - version = "placeholder" - - BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( - version) - - num_gpu_blocks = cache_config.num_gpu_blocks - if num_gpu_blocks: - num_gpu_blocks //= pipeline_parallel_size - - num_cpu_blocks = cache_config.num_cpu_blocks - if num_cpu_blocks: - num_cpu_blocks //= pipeline_parallel_size - - # Create the block space manager. - self.block_manager = BlockSpaceManagerImpl( - block_size=self.cache_config.block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - sliding_window=self.cache_config.sliding_window, - enable_caching=self.cache_config.enable_prefix_caching, - ) - - # Sequence groups in the WAITING state. - # Contain new prefill or preempted requests. - self.waiting: deque[SequenceGroup] = deque() - # Sequence groups in the RUNNING state. - # Contain decode requests. - self.running: deque[SequenceGroup] = deque() - # Sequence groups in the SWAPPED state. - # Contain decode requests that are swapped out. - self.swapped: deque[SequenceGroup] = deque() - # Sequence groups finished requests ids since last step iteration. - # It lets the model know that any state associated with these requests - # can and must be released after the current step. - # This is used to evict the finished requests from the Mamba cache. - self._finished_requests_ids: list[str] = list() - # Time at previous scheduling step - self.prev_time = 0.0 - # Did we schedule a prompt at previous step? - self.prev_prompt = False - # Latency of the last prompt step - self.last_prompt_latency = 0.0 - # preemption mode, RECOMPUTE or SWAP - self.user_specified_preemption_mode = scheduler_config.preemption_mode - - # The following field is test-only. It is used to inject artificial - # preemption. - self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT - self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT - if self.enable_artificial_preemption - else 0) - self.num_cumulative_preemption: int = 0 - - # Used to cache python objects - self._seq_group_metadata_cache: list[PyObjectCache] = [] - self._scheduler_running_outputs_cache: list[PyObjectCache] = [] - self._scheduled_seq_group_cache: list[PyObjectCache] = [] - - # For async output processing, we need to swap cache buffers between - # iterations. I.e. 
since the output processing is lagged one step, - # we cannot reuse the cached objects immediately when the schedule() - # is called again, but only when schedule() is called the second time. - self.output_proc_callback = output_proc_callback - self.use_async_output_proc = self.output_proc_callback is not None - self.num_cache_iters = 2 if self.use_async_output_proc else 1 - - self.cache_id = 0 - for i in range(self.num_cache_iters): - self._seq_group_metadata_cache.append( - PyObjectCache(seq_group_metadata_builder)) - self._scheduler_running_outputs_cache.append( - PyObjectCache(scheduler_running_outputs_builder)) - self._scheduled_seq_group_cache.append( - PyObjectCache(scheduled_seq_group_builder)) - - # For async postprocessor, the extra decode run cannot be done - # when the request reaches max_model_len. In this case, the request - # will be stopped during schedule() call and added to this stop list - # for processing and deallocation by the free_finished_seq_groups() - self._async_stopped: list[SequenceGroup] = [] - - # List with the chunk sizes to hand out to each sequence depending - # on how many partial prefills are running. This is slightly faster than - # running an integer division every time a prefill is scheduled. - # This splits the budget evenly among all prefills. - self.partial_prefill_budget_lookup_list = [0] * ( - self.scheduler_config.max_num_partial_prefills + 1) - self.partial_prefill_budget_lookup_list[0] = ( - scheduler_config.max_num_batched_tokens) - for i in range(1, self.scheduler_config.max_num_partial_prefills + 1): - self.partial_prefill_budget_lookup_list[i] = ( - scheduler_config.max_num_batched_tokens // i) - - # SPYRE SPECIFIC CODE BLOCK START - self.spyre_warmup_shapes = SpyrePlatform.get_warmup_shapes( - self.scheduler_config - ) - # SPYRE SPECIFIC CODE BLOCK END - - @property - def next_cache_id(self): - return (self.cache_id + 1) % self.num_cache_iters - - @property - def lora_enabled(self) -> bool: - return bool(self.lora_config) - - @property - def num_decoding_tokens_per_seq(self) -> int: - """The number of new tokens.""" - return 1 - - def add_seq_group(self, seq_group: SequenceGroup) -> None: - # Add sequence groups to the waiting queue. - self.waiting.append(seq_group) - - def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None: - # Add sequence groups to the running queue. - # Only for testing purposes. - self.running.append(seq_group) - - def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None: - # Add sequence groups to the swapped queue. - # Only for testing purposes. - self.swapped.append(seq_group) - - def abort_seq_group( - self, - request_id: Union[str, Iterable[str]], - seq_id_to_seq_group: Optional[dict[str, SequenceGroupBase]] = None, - ) -> None: - """Aborts a sequence group with the given ID. - - Check if the sequence group with the given ID - is present in any of the state queue. - If present, remove the sequence group from the state queue. - Also, if any of the sequences in the sequence group is not finished, - free the sequence with status `FINISHED_ABORTED`. - Otherwise, do nothing. - - Args: - request_id: The ID(s) of the sequence group to abort. 
- """ - if isinstance(request_id, str): - request_id = (request_id, ) - request_ids = set(request_id) - for state_queue in [self.waiting, self.running, self.swapped]: - aborted_groups: list[SequenceGroup] = [] - for seq_group in state_queue: - if not request_ids: - # Using 'break' here may add two extra iterations, - # but is acceptable to reduce complexity. - break - if seq_group.request_id in request_ids: - # Appending aborted group into pending list. - aborted_groups.append(seq_group) - request_ids.remove(seq_group.request_id) - for aborted_group in aborted_groups: - # Remove the sequence group from the state queue. - state_queue.remove(aborted_group) - # Remove the aborted request from the Mamba cache. - self._finished_requests_ids.append(aborted_group.request_id) - for seq in aborted_group.get_seqs(): - if seq.is_finished(): - continue - seq.status = SequenceStatus.FINISHED_ABORTED - self.free_seq(seq) - - self._free_seq_group_cross_attn_blocks(aborted_group) - - def _free_seq_group_cross_attn_blocks( - self, - seq_group: SequenceGroup, - ) -> None: - """ - Free a sequence group from a cross-attention block table. - Has no effect on decoder-only models. - """ - if seq_group.is_encoder_decoder(): - self.block_manager.free_cross(seq_group) - - def has_unfinished_seqs(self) -> bool: - return (len(self.waiting) != 0 or len(self.running) != 0 - or len(self.swapped) != 0) - - def get_prefix_cache_hit_rate(self, device: Device) -> float: - return self.block_manager.get_prefix_cache_hit_rate(device) - - def reset_prefix_cache(self) -> bool: - return self.block_manager.reset_prefix_cache() - - def get_num_unfinished_seq_groups(self) -> int: - return len(self.waiting) + len(self.running) + len(self.swapped) - - def get_and_reset_finished_requests_ids(self) -> list[str]: - """Flushes the list of request ids of previously finished seq_groups.""" - finished_requests_ids = self._finished_requests_ids - self._finished_requests_ids = list() - return finished_requests_ids - - def _schedule_running( - self, - budget: SchedulingBudget, - curr_loras: Optional[set[int]], - enable_chunking: bool = False, - partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, - ) -> SchedulerRunningOutputs: - """Schedule sequence groups that are running. - - Running queue should include decode and chunked prefill requests. - - Args: - budget: The scheduling budget. The argument is in-place updated - when any decodes are preempted. - curr_loras: Currently batched lora request ids. The argument is - in-place updated when any decodes are preempted. - enable_chunking: If True, seq group can be chunked and only a - chunked number of tokens are scheduled if - `budget.num_batched_tokens` has not enough capacity to schedule - all tokens. - partial_prefill_metadata: information about the partial prefills - that are currently running - - Returns: - SchedulerRunningOutputs. - """ - ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[ - self.cache_id].get_object() - ret.blocks_to_swap_out.clear() - ret.blocks_to_copy.clear() - ret.decode_seq_groups.clear() - ret.prefill_seq_groups.clear() - ret.preempted.clear() - ret.swapped_out.clear() - - ret.num_lookahead_slots = self._get_num_lookahead_slots( - is_prefill=False, enable_chunking=enable_chunking) - - ret.decode_seq_groups_list.clear() - ret.prefill_seq_groups_list.clear() - - # Blocks that need to be swapped or copied before model execution. 
- blocks_to_swap_out: list[tuple[int, int]] = ret.blocks_to_swap_out - blocks_to_copy: list[tuple[int, int]] = ret.blocks_to_copy - - decode_seq_groups: list[ScheduledSequenceGroup] = ret.decode_seq_groups - prefill_seq_groups: list[ - ScheduledSequenceGroup] = ret.prefill_seq_groups - preempted: list[SequenceGroup] = ret.preempted - swapped_out: list[SequenceGroup] = ret.swapped_out - - running_queue = self.running - assert len(self._async_stopped) == 0 - while running_queue: - seq_group = running_queue[0] - # We discard the cached tokens info here because we don't need it - # for running sequence: - # 1. If a sequence is running with chunked prefill, the cached - # tokens info was already used for the first prefill. - # 2. If a sequence is running with non-chunked prefill, then - # there it's a decoding sequence, and the cached tokens info is - # irrelevant. - num_uncached_new_tokens, _ = \ - self._get_num_new_uncached_and_cached_tokens( - seq_group, - SequenceStatus.RUNNING, - enable_chunking, - budget, - partial_prefill_metadata, - ) - - num_running_tokens = num_uncached_new_tokens - if num_running_tokens == 0: - # No budget => Stop - break - - running_queue.popleft() - - # With async postprocessor, an extra decode run is done - # to process the final tokens. The check below avoids this extra - # decode run when the model max len is reached, in order to avoid - # a memory overflow. - if (self.use_async_output_proc and seq_group.seqs[0].get_len() - > self.scheduler_config.max_model_len): - self._async_stopped.append(seq_group) - continue - - # NOTE(woosuk): Preemption happens only when there is no available - # slot to keep all the sequence groups in the RUNNING state. - while not self._can_append_slots(seq_group, enable_chunking): - budget.subtract_num_batched_tokens(seq_group.request_id, - num_running_tokens) - num_running_seqs = seq_group.get_max_num_running_seqs() - budget.subtract_num_seqs(seq_group.request_id, - num_running_seqs) - - if (curr_loras is not None and seq_group.lora_int_id > 0 - and seq_group.lora_int_id in curr_loras): - curr_loras.remove(seq_group.lora_int_id) - - # Determine victim sequence - cont_loop = True - if running_queue: - # Preempt the lowest-priority sequence group. - victim_seq_group = running_queue.pop() - else: - # No other sequence group can be preempted. - # Preempt the current sequence group. - # Note: This is also where we stop this loop - # (since there is nothing else to preempt) - victim_seq_group = seq_group - cont_loop = False - - # With async postprocessor, before preempting a sequence - # we need to ensure it has no pending async postprocessor - do_preempt = True - if self.use_async_output_proc: - assert self.output_proc_callback is not None - self.output_proc_callback( - request_id=victim_seq_group.request_id) - - # It may be that the async pending "victim_seq_group" - # becomes finished, in which case we simply free it. 
- if victim_seq_group.is_finished(): - self._free_finished_seq_group(victim_seq_group) - do_preempt = False - - # Do preemption - if do_preempt: - preempted_mode = self._preempt(victim_seq_group, - blocks_to_swap_out) - if preempted_mode == PreemptionMode.RECOMPUTE: - preempted.append(victim_seq_group) - else: - swapped_out.append(victim_seq_group) - - if not cont_loop: - break - else: - self._append_slots(seq_group, blocks_to_copy, enable_chunking) - is_prefill = seq_group.is_prefill() - - scheduled_seq_group: ScheduledSequenceGroup = ( - self._scheduled_seq_group_cache[ - self.cache_id].get_object()) - scheduled_seq_group.seq_group = seq_group - if is_prefill: - scheduled_seq_group.token_chunk_size = num_running_tokens - prefill_seq_groups.append(scheduled_seq_group) - ret.prefill_seq_groups_list.append(seq_group) - else: - scheduled_seq_group.token_chunk_size = 1 - decode_seq_groups.append(scheduled_seq_group) - ret.decode_seq_groups_list.append(seq_group) - - budget.add_num_batched_tokens(seq_group.request_id, - num_running_tokens) - # OPTIMIZATION: Note that get_max_num_running_seqs is - # expensive. For the default scheduling chase where - # enable_chunking is False, num_seqs are updated before running - # this method, so we don't have to update it again here. - if enable_chunking: - num_running_seqs = seq_group.get_max_num_running_seqs() - budget.add_num_seqs(seq_group.request_id, num_running_seqs) - if curr_loras is not None and seq_group.lora_int_id > 0: - curr_loras.add(seq_group.lora_int_id) - - self._scheduler_running_outputs_cache[self.next_cache_id].reset() - self._scheduled_seq_group_cache[self.next_cache_id].reset() - - return ret - - def _schedule_swapped( - self, - budget: SchedulingBudget, - curr_loras: Optional[set[int]], - enable_chunking: bool = False, - ) -> SchedulerSwappedInOutputs: - """Schedule sequence groups that are swapped out. - - It schedules swapped requests as long as it fits `budget` and - curr_loras <= max_lora from the scheduling config. The input arguments - `budget` and `curr_loras` are updated based on scheduled seq_groups. - - Args: - budget: The scheduling budget. The argument is in-place updated - when any requests are swapped in. - curr_loras: Currently batched lora request ids. The argument is - in-place updated when any requests are swapped in. - enable_chunking: If True, seq group can be chunked and only a - chunked number of tokens are scheduled if - `budget.num_batched_tokens` has not enough capacity to schedule - all tokens. - - Returns: - SchedulerSwappedInOutputs. - """ - # Blocks that need to be swapped or copied before model execution. - blocks_to_swap_in: list[tuple[int, int]] = [] - blocks_to_copy: list[tuple[int, int]] = [] - decode_seq_groups: list[ScheduledSequenceGroup] = [] - prefill_seq_groups: list[ScheduledSequenceGroup] = [] - infeasible_seq_groups: list[SequenceGroup] = [] - - swapped_queue = self.swapped - - leftover_swapped: deque[SequenceGroup] = deque() - while swapped_queue: - seq_group = swapped_queue[0] - - # If the sequence group cannot be swapped in, stop. 
- is_prefill = seq_group.is_prefill() - alloc_status = self.block_manager.can_swap_in( - seq_group, - self._get_num_lookahead_slots(is_prefill, enable_chunking)) - if alloc_status == AllocStatus.LATER: - break - elif alloc_status == AllocStatus.NEVER: - logger.warning( - "Failing the request %s because there's not enough kv " - "cache blocks to run the entire sequence.", - seq_group.request_id, - ) - for seq in seq_group.get_seqs(): - seq.status = SequenceStatus.FINISHED_IGNORED - infeasible_seq_groups.append(seq_group) - swapped_queue.popleft() - continue - - lora_int_id = 0 - if self.lora_enabled: - lora_int_id = seq_group.lora_int_id - assert curr_loras is not None - assert self.lora_config is not None - if (lora_int_id > 0 and (lora_int_id not in curr_loras) - and len(curr_loras) >= self.lora_config.max_loras): - # We don't have a space for another LoRA, so - # we ignore this request for now. - leftover_swapped.appendleft(seq_group) - swapped_queue.popleft() - continue - - # The total number of sequences in the RUNNING state should not - # exceed the maximum number of sequences. - num_new_seqs = seq_group.get_max_num_running_seqs() - num_new_tokens_uncached, num_new_tokens_cached = ( - self._get_num_new_uncached_and_cached_tokens( - seq_group, SequenceStatus.SWAPPED, enable_chunking, - budget)) - - if num_new_tokens_uncached == 0 or not budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, - ): - break - - if lora_int_id > 0 and curr_loras is not None: - curr_loras.add(lora_int_id) - swapped_queue.popleft() - self._swap_in(seq_group, blocks_to_swap_in) - self._append_slots(seq_group, blocks_to_copy, enable_chunking) - if is_prefill: - prefill_seq_groups.append( - ScheduledSequenceGroup( - seq_group, - token_chunk_size=num_new_tokens_uncached + - num_new_tokens_cached, - )) - else: - decode_seq_groups.append( - ScheduledSequenceGroup(seq_group, token_chunk_size=1)) - budget.add_num_batched_tokens( - seq_group.request_id, - num_batched_tokens=num_new_tokens_uncached, - num_cached_tokens=num_new_tokens_cached, - ) - budget.add_num_seqs(seq_group.request_id, num_new_seqs) - - swapped_queue.extendleft(leftover_swapped) - - return SchedulerSwappedInOutputs( - decode_seq_groups=decode_seq_groups, - prefill_seq_groups=prefill_seq_groups, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_copy=blocks_to_copy, - num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=False, enable_chunking=enable_chunking), - infeasible_seq_groups=infeasible_seq_groups, - ) - - def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: - if (self.scheduler_config.chunked_prefill_enabled - and not self.scheduler_config.is_multi_step): - prompt_limit = self.scheduler_config.max_model_len - else: - prompt_limit = min( - self.scheduler_config.max_model_len, - self.scheduler_config.max_num_batched_tokens, - ) - - # Model is fine tuned with long context. Return the fine tuned max_len. - if seq_group.lora_request and seq_group.lora_request.long_lora_max_len: - assert prompt_limit <= seq_group.lora_request.long_lora_max_len - return seq_group.lora_request.long_lora_max_len - else: - return prompt_limit - - def _get_priority(self, - seq_group: SequenceGroup) -> tuple[Optional[int], float]: - """Get the priority of the sequence group. - Highest preference to user-defined priority, followed by arrival time. - Args: - seq_group: The sequence group input. - Returns: - The priority of the sequence group. 
- """ - return seq_group.priority, seq_group.arrival_time - - def _schedule_priority_preemption( - self, - budget: SchedulingBudget, - ) -> int: - """Sorts waiting and running queue. Also, force preempt requests - from the running queue if their priority is lower. - Priority-based preemption is used with the priority policy. - Args: - budget: The scheduling budget. The argument is in-place updated - when any requests are scheduled. - Returns: - A count of priority-based preemptions. - """ - - waiting_queue = self.waiting - - running_queue = deque(sorted(self.running, key=self._get_priority)) - - blocks_to_swap_out: list[tuple[int, int]] = [] - force_preemption_count = 0 - - if waiting_queue: - seq_group = waiting_queue.popleft() - num_new_seqs = seq_group.get_max_num_running_seqs() - num_new_tokens_uncached, _ = \ - self._get_num_new_uncached_and_cached_tokens( - seq_group, SequenceStatus.WAITING, False, budget) - - # Only preempt if priority inversion exists - while running_queue and self._get_priority( - running_queue[-1]) > self._get_priority(seq_group): - # Only preempt if waiting sequence cannot be allocated - can_allocate = self.block_manager.can_allocate(seq_group) - if (num_new_tokens_uncached > 0 - and can_allocate == AllocStatus.OK - and budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, - )): - break - - # Adjust budget to remove the victim sequence group - vseq_group = running_queue.pop() - num_running_tokens_uncached, _ = ( - self._get_num_new_uncached_and_cached_tokens( - vseq_group, SequenceStatus.RUNNING, False, budget)) - budget.subtract_num_batched_tokens( - vseq_group.request_id, num_running_tokens_uncached) - num_running_seqs = vseq_group.get_max_num_running_seqs() - budget.subtract_num_seqs(vseq_group.request_id, - num_running_seqs) - - # Preempt out the victim sequence group - self._preempt(vseq_group, blocks_to_swap_out) - waiting_queue.appendleft(vseq_group) - force_preemption_count += 1 - # Put the sequence back into the waiting queue - waiting_queue.appendleft(seq_group) - - waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) - - self.waiting = waiting_queue - self.running = running_queue - return force_preemption_count - - def _schedule_prefills( - self, - budget: SchedulingBudget, - curr_loras: Optional[set[int]], - enable_chunking: bool = False, - partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, - ) -> SchedulerPrefillOutputs: - """Schedule sequence groups that are in prefill stage. - - Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE - as a new prefill (that starts from beginning -> most recently generated - tokens). - - It schedules waiting requests as long as it fits `budget` and - curr_loras <= max_lora from the scheduling config. The input arguments - `budget` and `curr_loras` are updated based on scheduled seq_groups. - - Args: - budget: The scheduling budget. The argument is in-place updated - when any requests are scheduled. - curr_loras: Currently batched lora request ids. The argument is - in-place updated when any requests are scheduled. - enable_chunking: If True, seq group can be chunked and only a - chunked number of tokens are scheduled if - `budget.num_batched_tokens` has not enough capacity to schedule - all tokens. - partial_prefill_metadata: information about the partial prefills - that are currently running - - Returns: - SchedulerPrefillOutputs. 
- """ - if budget.remaining_token_budget() == 0: - # Do nothing: Can't add any more prefill anyway - return SchedulerPrefillOutputs( - seq_groups=[], - ignored_seq_groups=[], - num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=True, enable_chunking=enable_chunking), - ) - ignored_seq_groups: list[SequenceGroup] = [] - seq_groups: list[ScheduledSequenceGroup] = [] - - # SPYRE SPECIFIC CODE BLOCK START - applicable_spyre_warmup_shapes = list(self.spyre_warmup_shapes) - # SPYRE SPECIFIC CODE BLOCK END - - waiting_queue = self.waiting - - leftover_waiting_sequences: deque[SequenceGroup] = deque() - while self._passed_delay(time.time()) and waiting_queue: - seq_group = waiting_queue[0] - - waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) - assert len(waiting_seqs) == 1, ( - "Waiting sequence group should have only one prompt " - "sequence.") - if (partial_prefill_metadata is not None - and not partial_prefill_metadata.can_schedule(seq_group)): - leftover_waiting_sequences.appendleft(seq_group) - waiting_queue.popleft() - continue - num_new_tokens_uncached, num_new_tokens_cached = ( - self._get_num_new_uncached_and_cached_tokens( - seq_group, - SequenceStatus.WAITING, - enable_chunking, - budget, - partial_prefill_metadata=partial_prefill_metadata, - )) - num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached - - if not enable_chunking: - num_prompt_tokens = waiting_seqs[0].get_len() - assert num_new_tokens == num_prompt_tokens - - prompt_limit = self._get_prompt_limit(seq_group) - if num_new_tokens > prompt_limit: - logger.warning( - "Input prompt (%d tokens) is too long" - " and exceeds limit of %d", - num_new_tokens, - prompt_limit, - ) - for seq in waiting_seqs: - seq.status = SequenceStatus.FINISHED_IGNORED - ignored_seq_groups.append(seq_group) - waiting_queue.popleft() - continue - - num_lookahead_slots: int = 0 - if self.scheduler_config.is_multi_step and enable_chunking: - num_lookahead_slots = self._get_num_lookahead_slots( - True, enable_chunking) - - # If the sequence group cannot be allocated, stop. - can_allocate = self.block_manager.can_allocate( - seq_group, num_lookahead_slots=num_lookahead_slots) - if can_allocate == AllocStatus.LATER: - break - elif can_allocate == AllocStatus.NEVER: - logger.warning( - "Input prompt (%d tokens) + lookahead slots (%d) is " - "too long and exceeds the capacity of block_manager", - num_new_tokens, - num_lookahead_slots, - ) - for seq in waiting_seqs: - seq.status = SequenceStatus.FINISHED_IGNORED - ignored_seq_groups.append(seq_group) - waiting_queue.popleft() - continue - - lora_int_id = 0 - if self.lora_enabled: - lora_int_id = seq_group.lora_int_id - assert curr_loras is not None - assert self.lora_config is not None - if (self.lora_enabled and lora_int_id > 0 - and lora_int_id not in curr_loras - and len(curr_loras) >= self.lora_config.max_loras): - # We don't have a space for another LoRA, so - # we ignore this request for now. - leftover_waiting_sequences.appendleft(seq_group) - waiting_queue.popleft() - continue - - if (budget.num_batched_tokens - >= self.scheduler_config.max_num_batched_tokens): - # We've reached the budget limit - since there might be - # continuous prefills in the running queue, we should break - # to avoid scheduling any new prefills. 
- break - - num_new_seqs = seq_group.get_max_num_running_seqs() - if num_new_tokens_uncached == 0 or not budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, - ): - break - - # SPYRE SPECIFIC CODE BLOCK START - # check if current request can be scheduled based on the applicable - # spyre warmup shapes - max_tokens = 0 - if seq_group.sampling_params is not None and\ - seq_group.sampling_params.max_tokens is not None: - max_tokens = seq_group.sampling_params.max_tokens - updated_spyre_warmup_shapes = [ - shape for shape in applicable_spyre_warmup_shapes - if num_new_tokens <= shape['prompt_length'] and max_tokens <= - shape['new_tokens'] and len(seq_groups) < shape['batch_size'] - ] - if not updated_spyre_warmup_shapes: - if not seq_groups: - # request was tested against all spyre warmup shapes: - # request cannot be processed - if (seq_group.sampling_params is not None and - seq_group.sampling_params.max_tokens is not None): - logger.warning( - "No applicable warmup shape exists for " - "combination of prompt length (%d tokens) " - "and maximum number of output tokens to be " - "generated (%d tokens)", num_new_tokens, - seq_group.sampling_params.max_tokens) - else: - logger.warning( - "No applicable warmup shape exists for " - "combination of prompt length (%d tokens) " - "and undefined maximum number of output " - "tokens", num_new_tokens) - for seq in waiting_seqs: - seq.status = SequenceStatus.FINISHED_IGNORED - ignored_seq_groups.append(seq_group) - waiting_queue.popleft() - continue - else: - # request was only tested against spyre warmup shapes - # that remain after processing previous requests in - # waiting queue: request will be evaluated again in - # a future scheduling step - leftover_waiting_sequences.appendleft(seq_group) - waiting_queue.popleft() - continue - else: - applicable_spyre_warmup_shapes = updated_spyre_warmup_shapes - # SPYRE SPECIFIC CODE BLOCK END - - # Can schedule this request. - if curr_loras is not None and lora_int_id > 0: - curr_loras.add(lora_int_id) - waiting_queue.popleft() - self._allocate_and_set_running(seq_group) - - if partial_prefill_metadata is not None: - partial_prefill_metadata.maybe_increment_partial_prefills( - seq_group) - - if enable_chunking and self.scheduler_config.is_multi_step: - blocks_to_copy: list[tuple[int, int]] = [] - # init_multi_step_from_lookahead_slots happens in append_slots - self._append_slots(seq_group, blocks_to_copy, enable_chunking) - # This assert will trip when a copy-on-write happens. This is - # not a concern as the very first sequence-group block - # allocation happens above. Still, we have the assert to - # catch any edge-cases. - assert not blocks_to_copy - else: - seq_group.init_multi_step_from_lookahead_slots( - num_lookahead_slots, - num_scheduler_steps=self.scheduler_config. 
- num_scheduler_steps, - is_multi_step=self.scheduler_config.is_multi_step, - enable_chunking=enable_chunking, - ) - - seq_groups.append( - ScheduledSequenceGroup(seq_group=seq_group, - token_chunk_size=num_new_tokens)) - budget.add_num_batched_tokens( - seq_group.request_id, - num_batched_tokens=num_new_tokens_uncached, - num_cached_tokens=num_new_tokens_cached, - ) - budget.add_num_seqs(seq_group.request_id, num_new_seqs) - - # SPYRE SPECIFIC CODE BLOCK START - # Check if number of scheduled requests has reached the maximum - # batch size of the applicable warmup shapes - if len(seq_groups) >= max([ - shape['batch_size'] - for shape in applicable_spyre_warmup_shapes - ]): - break - # SPYRE SPECIFIC CODE BLOCK END - - # Queue requests that couldn't be scheduled. - waiting_queue.extendleft(leftover_waiting_sequences) - if len(seq_groups) > 0: - self.prev_prompt = True - - return SchedulerPrefillOutputs( - seq_groups=seq_groups, - ignored_seq_groups=ignored_seq_groups, - num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=True, enable_chunking=enable_chunking), - ) - - def _schedule_default(self) -> SchedulerOutputs: - """Schedule queued requests. - - The current policy is designed to optimize the throughput. First, - it batches as many prefill requests as possible. And it schedules - decodes. If there's a pressure on GPU memory, decode requests can - be swapped or preempted. - """ - # Include running requests to the budget. - budget = SchedulingBudget( - token_budget=self.scheduler_config.max_num_batched_tokens, - max_num_seqs=self.scheduler_config.max_num_seqs, - ) - # Make sure we include num running seqs before scheduling prefill, - # so that we don't schedule beyond max_num_seqs for prefill. - for seq_group in self.running: - budget.add_num_seqs(seq_group.request_id, - seq_group.get_max_num_running_seqs()) - curr_loras = (set( - seq_group.lora_int_id for seq_group in self.running - if seq_group.lora_int_id > 0) if self.lora_enabled else None) - - prefills = SchedulerPrefillOutputs.create_empty() - running_scheduled = SchedulerRunningOutputs.create_empty() - swapped_in = SchedulerSwappedInOutputs.create_empty() - - # SPYRE SPECIFIC CODE BLOCK START - # Schedule new prefills only when no requests have been swapped - # and all previous decodes have completed (Spyre constraint). - if not self.swapped and not self.running: - prefills = self._schedule_prefills(budget, - curr_loras, - enable_chunking=False) - # SPYRE SPECIFIC CODE BLOCK END - - if len(prefills.seq_groups - ) == 0 and self.scheduler_config.policy == "priority": - self._schedule_priority_preemption(budget) - - # Don't schedule decodes if prefills are scheduled. - # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running - # only contains decode requests, not chunked prefills. - if len(prefills.seq_groups) == 0: - running_scheduled = self._schedule_running(budget, - curr_loras, - enable_chunking=False) - - # If any sequence group is preempted, do not swap in any sequence - # group. because it means there's no slot for new running requests. - if (len(running_scheduled.preempted) + - len(running_scheduled.swapped_out) == 0): - swapped_in = \ - self._schedule_swapped(budget, curr_loras) - - assert (budget.num_batched_tokens - <= self.scheduler_config.max_num_batched_tokens) - assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs - - # Update waiting requests. - self.waiting.extendleft(running_scheduled.preempted) - # Update new running requests. 
- if len(prefills.seq_groups) > 0: - self.running.extend([s.seq_group for s in prefills.seq_groups]) - - self.running.extend(running_scheduled.decode_seq_groups_list) - - if len(swapped_in.decode_seq_groups) > 0: - self.running.extend( - [s.seq_group for s in swapped_in.decode_seq_groups]) - - # Update swapped requests. - self.swapped.extend(running_scheduled.swapped_out) - preempted = len(running_scheduled.preempted) + len( - running_scheduled.swapped_out) - - # There should be no prefill from running queue because this policy - # doesn't allow chunked prefills. - assert len(running_scheduled.prefill_seq_groups) == 0 - assert len(swapped_in.prefill_seq_groups) == 0 - - # Merge lists - num_prefill_groups = len(prefills.seq_groups) - if num_prefill_groups > 0: - scheduled_seq_groups = prefills.seq_groups - scheduled_seq_groups.extend(running_scheduled.decode_seq_groups) - else: - scheduled_seq_groups = running_scheduled.decode_seq_groups - scheduled_seq_groups.extend(swapped_in.decode_seq_groups) - - blocks_to_copy = running_scheduled.blocks_to_copy - blocks_to_copy.extend(swapped_in.blocks_to_copy) - - ignored_seq_groups = prefills.ignored_seq_groups - ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) - - return SchedulerOutputs( - scheduled_seq_groups=scheduled_seq_groups, - num_prefill_groups=num_prefill_groups, - num_batched_tokens=budget.num_batched_tokens + - budget.num_cached_tokens, - blocks_to_swap_in=swapped_in.blocks_to_swap_in, - blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ignored_seq_groups=ignored_seq_groups, - num_lookahead_slots=running_scheduled.num_lookahead_slots, - running_queue_size=len(self.running), - preempted=preempted, - ) - - def _schedule_chunked_prefill(self) -> SchedulerOutputs: - """Schedule queued requests. - - Chunked prefill allows to chunk prefill requests, batch them together - with decode requests. This policy 1. schedule as many decoding requests - as possible. 2. schedule chunked prefill requests that are not - finished. 3. schedule swapped request. 4. schedule new prefill - requests. - - The policy can sustain the high GPU utilization because it can put - prefill and decodes requests to the same batch, while it improves - inter token latency because decodes requests don't need to be blocked - by prefill requests. - """ - budget = SchedulingBudget( - token_budget=self.scheduler_config.max_num_batched_tokens, - max_num_seqs=self.scheduler_config.max_num_seqs, - ) - curr_loras: set[int] = set() - - prefills = SchedulerPrefillOutputs.create_empty() - swapped_in = SchedulerSwappedInOutputs.create_empty() - - # Create partial prefill metadata - partial_prefill_metadata = PartialPrefillMetadata.from_queues( - running=self.running, - waiting=self.waiting, - scheduler_config=self.scheduler_config, - ) - - # Decoding should be always scheduled first by fcfs. - running_scheduled = self._schedule_running( - budget, - curr_loras, - enable_chunking=True, - partial_prefill_metadata=partial_prefill_metadata, - ) - - # Schedule swapped out requests. - # If preemption happens, it means we don't have space for swap-in. 
- if len(running_scheduled.preempted) + len( - running_scheduled.swapped_out) == 0: - swapped_in = self._schedule_swapped(budget, curr_loras) - - prefills = self._schedule_prefills( - budget, - curr_loras, - enable_chunking=True, - partial_prefill_metadata=partial_prefill_metadata, - ) - - assert (budget.num_batched_tokens - <= self.scheduler_config.max_num_batched_tokens) - assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs - - # Update waiting requests. - self.waiting.extendleft(running_scheduled.preempted) - - # Update new running requests. - # By default, vLLM scheduler prioritizes prefills. - # Once chunked prefill is enabled, - # the policy is changed to prioritize decode requests. - self.running.extend( - [s.seq_group for s in swapped_in.decode_seq_groups]) - self.running.extend( - [s.seq_group for s in swapped_in.prefill_seq_groups]) - self.running.extend( - [s.seq_group for s in running_scheduled.decode_seq_groups]) - # Because multiple prefills may be running concurrently, we need to - # make sure that prefills which are scheduled to finish are listed - # before those that won't. This is so that on the next scheduling - # iteration when they have transitioned to the decode stage, they are - # properly prioritized over sequences that are still in the prefill - # stage. - self.running.extend( - self._order_finishing_prefills_first( - running_scheduled.prefill_seq_groups)) - self.running.extend([s.seq_group for s in prefills.seq_groups]) - - # Update swapped requests. - self.swapped.extend(running_scheduled.swapped_out) - # Put prefills first due to Attention backend ordering assumption. - scheduled_seq_groups = (prefills.seq_groups + - running_scheduled.prefill_seq_groups + - swapped_in.prefill_seq_groups + - running_scheduled.decode_seq_groups + - swapped_in.decode_seq_groups) - num_prefill_groups = (len(prefills.seq_groups) + - len(swapped_in.prefill_seq_groups) + - len(running_scheduled.prefill_seq_groups)) - # If all prompts, then we set num_lookahead_slots to 0 - # this allows us to go through the `no_spec` path in - # `spec_decode_worker.py` - all_prefills = len(scheduled_seq_groups) == num_prefill_groups - num_lookahead_slots = (0 if - (all_prefills - and not self.scheduler_config.is_multi_step) - else running_scheduled.num_lookahead_slots) - return SchedulerOutputs( - scheduled_seq_groups=scheduled_seq_groups, - num_prefill_groups=num_prefill_groups, - num_batched_tokens=budget.num_batched_tokens + - budget.num_cached_tokens, - blocks_to_swap_in=swapped_in.blocks_to_swap_in, - blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=running_scheduled.blocks_to_copy + - swapped_in.blocks_to_copy, - ignored_seq_groups=prefills.ignored_seq_groups + - swapped_in.infeasible_seq_groups, - num_lookahead_slots=num_lookahead_slots, - running_queue_size=len(self.running), - preempted=(len(running_scheduled.preempted) + - len(running_scheduled.swapped_out)), - ) - - def _order_finishing_prefills_first( - self, scheduled_prefill_seqs: list[ScheduledSequenceGroup] - ) -> list[SequenceGroup]: - """Returns a list of prefilling SequenceGroups where sequences that are - scheduled to finish prefilling are listed first""" - finishing = [ - s.seq_group for s in scheduled_prefill_seqs - if s.seq_group.get_num_uncomputed_tokens() == s.token_chunk_size - ] - not_finishing = [ - s.seq_group for s in scheduled_prefill_seqs - if s.seq_group.get_num_uncomputed_tokens() != s.token_chunk_size - ] - return finishing + not_finishing - - def _schedule(self) -> 
SchedulerOutputs: - """Schedule queued requests.""" - if self.scheduler_config.chunked_prefill_enabled: - return self._schedule_chunked_prefill() - else: - return self._schedule_default() - - def _can_append_slots(self, seq_group: SequenceGroup, - enable_chunking: bool) -> bool: - """Determine whether or not we have enough space in the KV cache to - continue generation of the sequence group. - """ - # It is True only for testing case to trigger artificial preemption. - if (self.enable_artificial_preemption - and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB - and self.artificial_preempt_cnt > 0): - self.artificial_preempt_cnt -= 1 - return False - - is_prefill = seq_group.is_prefill() - num_lookahead_slots = self._get_num_lookahead_slots( - is_prefill, enable_chunking) - - if is_prefill and num_lookahead_slots > 0: - # Appending prefill slots only happens multi-step and - # chunked-prefill are enabled together. - assert self.scheduler_config.is_multi_step and enable_chunking - - return self.block_manager.can_append_slots( - seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) - - def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: - # async_output_proc is allowed only when we have a single sequence - # in the sequence group - no_single_seq = seq_group.sampling_params is None or ( - seq_group.sampling_params.n == 1) - return no_single_seq - - def schedule( - self - ) -> tuple[list[SequenceGroupMetadata], SchedulerOutputs, bool]: - # Schedule sequence groups. - # This function call changes the internal states of the scheduler - # such as self.running, self.swapped, and self.waiting. - scheduler_start_time = time.perf_counter() - - scheduler_outputs: SchedulerOutputs = self._schedule() - now = time.time() - - if not self.cache_config.enable_prefix_caching: - common_computed_block_nums = [] - - allow_async_output_proc: bool = self.use_async_output_proc - - # Create input data structures. - seq_group_metadata_list: list[SequenceGroupMetadata] = [] - for i, scheduled_seq_group in enumerate( - scheduler_outputs.scheduled_seq_groups): - seq_group = scheduled_seq_group.seq_group - token_chunk_size = scheduled_seq_group.token_chunk_size - seq_group.maybe_set_first_scheduled_time(now) - - seq_group_metadata = self._seq_group_metadata_cache[ - self.cache_id].get_object() - seq_group_metadata.seq_data.clear() - seq_group_metadata.block_tables.clear() - - # seq_id -> SequenceData - seq_data: dict[int, SequenceData] = {} - # seq_id -> physical block numbers - block_tables: dict[int, list[int]] = {} - - if seq_group.is_encoder_decoder(): - # Encoder associated with SequenceGroup - encoder_seq = seq_group.get_encoder_seq() - assert encoder_seq is not None - encoder_seq_data = encoder_seq.data - # Block table for cross-attention - # Also managed at SequenceGroup level - cross_block_table = self.block_manager.get_cross_block_table( - seq_group) - else: - encoder_seq_data = None - cross_block_table = None - - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - seq_id = seq.seq_id - seq_data[seq_id] = seq.data - block_tables[seq_id] = self.block_manager.get_block_table(seq) - self.block_manager.access_all_blocks_in_seq(seq, now) - - if self.cache_config.enable_prefix_caching: - common_computed_block_nums = ( - self.block_manager.get_common_computed_block_ids( - seq_group.get_seqs(status=SequenceStatus.RUNNING))) - - do_sample = True - is_prompt = seq_group.is_prefill() - # We should send the metadata to workers when the first prefill - # is sent. 
Subsequent requests could be chunked prefill or decode. - is_first_prefill = False - if is_prompt: - seqs = seq_group.get_seqs() - # Prefill has only 1 sequence. - assert len(seqs) == 1 - num_computed_tokens = seqs[0].data.get_num_computed_tokens() - is_first_prefill = num_computed_tokens == 0 - # In the next iteration, all prompt tokens are not computed. - # It means the prefill is chunked, and we don't need sampling. - # NOTE: We use get_len instead of get_prompt_len because when - # a sequence is preempted, prefill includes previous generated - # output tokens. - if (token_chunk_size + num_computed_tokens - < seqs[0].data.get_len()): - do_sample = False - - # It assumes the scheduled_seq_groups is ordered by - # prefill < decoding. - if is_first_prefill or not self.scheduler_config.send_delta_data: - seq_group_metadata = SequenceGroupMetadata( - request_id=seq_group.request_id, - is_prompt=is_prompt, - seq_data=seq_data, - sampling_params=seq_group.sampling_params, - block_tables=block_tables, - do_sample=do_sample, - pooling_params=seq_group.pooling_params, - token_chunk_size=token_chunk_size, - lora_request=seq_group.lora_request, - computed_block_nums=common_computed_block_nums, - encoder_seq_data=encoder_seq_data, - cross_block_table=cross_block_table, - state=seq_group.state, - token_type_ids=seq_group.token_type_ids, - # `multi_modal_data` will only be present for the 1st comm - # between engine and worker. - # the subsequent comms can still use delta, but - # `multi_modal_data` will be None. - multi_modal_data=(seq_group.multi_modal_data - if scheduler_outputs.num_prefill_groups - > 0 else None), - multi_modal_placeholders=( - seq_group.multi_modal_placeholders - if scheduler_outputs.num_prefill_groups > 0 else None), - ) - else: - # When SPMD mode is enabled, we only send delta data except for - # the first request to reduce serialization cost. - seq_data_delta = {} - for id, data in seq_data.items(): - seq_data_delta[id] = data.get_delta_and_reset() - seq_group_metadata = SequenceGroupMetadataDelta( - seq_data_delta, - seq_group.request_id, - block_tables, - is_prompt, - do_sample=do_sample, - token_chunk_size=token_chunk_size, - computed_block_nums=common_computed_block_nums, - ) - seq_group_metadata_list.append(seq_group_metadata) - - if allow_async_output_proc: - allow_async_output_proc = self._allow_async_output_proc( - seq_group) - - # Now that the batch has been created, we can assume all blocks in the - # batch will have been computed before the next scheduling invocation. - # This is because the engine assumes that a failure in model execution - # will crash the vLLM instance / will not retry. - for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: - self.block_manager.mark_blocks_as_computed( - scheduled_seq_group.seq_group, - scheduled_seq_group.token_chunk_size) - - self._seq_group_metadata_cache[self.next_cache_id].reset() - - scheduler_time = time.perf_counter() - scheduler_start_time - # Add this to scheduler time to all the sequences that are currently - # running. This will help estimate if the scheduler is a significant - # component in the e2e latency. 
- for seq_group in self.running: - if seq_group is not None and seq_group.metrics is not None: - if seq_group.metrics.scheduler_time is not None: - seq_group.metrics.scheduler_time += scheduler_time - else: - seq_group.metrics.scheduler_time = scheduler_time - - # Move to next cache (if exists) - self.cache_id = self.next_cache_id - - # Return results - return (seq_group_metadata_list, scheduler_outputs, - allow_async_output_proc) - - def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: - self.block_manager.fork(parent_seq, child_seq) - - def free_seq(self, seq: Sequence) -> None: - """Free a sequence from a block table.""" - self.block_manager.free(seq) - - def _free_finished_seqs(self, seq_group: SequenceGroup) -> None: - """Free finished seqs in a sequence group.""" - for seq in seq_group.get_seqs(): - if seq.is_finished(): - self.free_seq(seq) - - def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None: - if seq_group.is_finished(): - # Free cross-attention block table, if it exists - self._free_seq_group_cross_attn_blocks(seq_group) - - # Add the finished requests to the finished requests list. - # This list will be used to update the Mamba cache in the - # next step. - self._finished_requests_ids.append(seq_group.request_id) - - # Free finished seqs - self._free_finished_seqs(seq_group) - - def free_finished_seq_groups(self) -> None: - remaining: deque[SequenceGroup] = deque() - for seq_group in self.running: - self._free_finished_seq_group(seq_group) - if not seq_group.is_finished(): - remaining.append(seq_group) - - self.running = remaining - - # Handle async stopped sequence groups - # (ones that reached max model len) - if self._async_stopped: - for seq_group in self._async_stopped: - self._free_seq_group_cross_attn_blocks(seq_group) - self._finished_requests_ids.append(seq_group.request_id) - - # Free finished seqs - self._free_finished_seqs(seq_group) - - self._async_stopped.clear() - - def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: - self.block_manager.allocate(seq_group) - for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): - seq.status = SequenceStatus.RUNNING - - def _append_slots( - self, - seq_group: SequenceGroup, - blocks_to_copy: list[tuple[int, int]], - enable_chunking: bool = False, - ) -> None: - """Appends new slots to the sequences in the given sequence group. - - Args: - seq_group (SequenceGroup): The sequence group containing the - sequences to append slots to. - blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two - ints, the first int is the source block index, and the second - int is the destination block index. This list is updated with - the new source and destination block indices for the appended - slots. - enable_chunking (bool): True if chunked prefill is enabled. - """ - is_prefill: bool = seq_group.is_prefill() - num_lookahead_slots: int = self._get_num_lookahead_slots( - is_prefill, enable_chunking) - - seq_group.init_multi_step_from_lookahead_slots( - num_lookahead_slots, - num_scheduler_steps=self.scheduler_config.num_scheduler_steps, - is_multi_step=self.scheduler_config.is_multi_step, - enable_chunking=enable_chunking, - ) - - seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING - if self.scheduler_config.is_multi_step and enable_chunking: - # In multi-step chunked-prefill any sequence type can have - # slots appended. 
- seq_status = None - - for seq in seq_group.get_seqs(status=seq_status): - cows = self.block_manager.append_slots(seq, num_lookahead_slots) - if len(cows) > 0: - blocks_to_copy.extend(cows) - - def _preempt(self, seq_group: SequenceGroup, - blocks_to_swap_out: list[tuple[int, int]]) -> PreemptionMode: - # If preemption mode is not specified, we determine the mode as follows: - # We use recomputation by default since it incurs lower overhead than - # swapping. However, when the sequence group has multiple sequences - # (e.g., beam search), recomputation is not currently supported. In - # such a case, we use swapping instead. - # FIXME(woosuk): This makes our scheduling policy a bit bizarre. - # As swapped sequences are prioritized over waiting sequences, - # sequence groups with multiple sequences are implicitly prioritized - # over sequence groups with a single sequence. - # TODO(woosuk): Support recomputation for sequence groups with multiple - # sequences. This may require a more sophisticated CUDA kernel. - if self.user_specified_preemption_mode is None: - if seq_group.get_max_num_running_seqs() == 1: - preemption_mode = PreemptionMode.RECOMPUTE - else: - preemption_mode = PreemptionMode.SWAP - - elif self.user_specified_preemption_mode == "swap": - preemption_mode = PreemptionMode.SWAP - else: - preemption_mode = PreemptionMode.RECOMPUTE - - if self.num_cumulative_preemption % 50 == 0: - logger.warning( - "Sequence group %s is preempted by %s mode because there is " - "not enough KV cache space. This can affect the end-to-end " - "performance. Increase gpu_memory_utilization or " - "tensor_parallel_size to provide more KV cache memory. " - "total_num_cumulative_preemption=%d", - seq_group.request_id, - preemption_mode, - self.num_cumulative_preemption + 1, - ) - self.num_cumulative_preemption += 1 - - if preemption_mode == PreemptionMode.RECOMPUTE: - self._preempt_by_recompute(seq_group) - elif preemption_mode == PreemptionMode.SWAP: - self._preempt_by_swap(seq_group, blocks_to_swap_out) - else: - raise AssertionError("Invalid preemption mode.") - return preemption_mode - - def _preempt_by_recompute( - self, - seq_group: SequenceGroup, - ) -> None: - seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - assert len(seqs) == 1 - for seq in seqs: - seq.status = SequenceStatus.WAITING - self.free_seq(seq) - seq.reset_state_for_recompute() - self._free_seq_group_cross_attn_blocks(seq_group) - - def _preempt_by_swap( - self, - seq_group: SequenceGroup, - blocks_to_swap_out: list[tuple[int, int]], - ) -> None: - self._swap_out(seq_group, blocks_to_swap_out) - - def _swap_in( - self, - seq_group: SequenceGroup, - blocks_to_swap_in: list[tuple[int, int]], - ) -> None: - mapping = self.block_manager.swap_in(seq_group) - blocks_to_swap_in.extend(mapping) - for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - seq.status = SequenceStatus.RUNNING - - def _swap_out( - self, - seq_group: SequenceGroup, - blocks_to_swap_out: list[tuple[int, int]], - ) -> None: - if not self.block_manager.can_swap_out(seq_group): - # FIXME(woosuk): Abort the sequence group instead of aborting the - # entire engine. - raise RuntimeError( - "Aborted due to the lack of CPU swap space. 
Please increase " - "the swap space to avoid this error.") - mapping = self.block_manager.swap_out(seq_group) - blocks_to_swap_out.extend(mapping) - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - seq.status = SequenceStatus.SWAPPED - - def _passed_delay(self, now: float) -> bool: - if self.prev_prompt: - self.last_prompt_latency = now - self.prev_time - self.prev_time, self.prev_prompt = now, False - # Delay scheduling prompts to let waiting queue fill up - if self.scheduler_config.delay_factor > 0 and self.waiting: - earliest_arrival_time = min( - [e.metrics.arrival_time for e in self.waiting]) - passed_delay = ((now - earliest_arrival_time) - > (self.scheduler_config.delay_factor * - self.last_prompt_latency) or not self.running) - else: - passed_delay = True - return passed_delay - - def _get_num_lookahead_slots(self, is_prefill: bool, - enable_chunking: bool) -> int: - """The number of slots to allocate per sequence per step, beyond known - token ids. Speculative decoding uses these slots to store KV activations - of tokens which may or may not be accepted. - - Speculative decoding does not yet support prefill, so we do not perform - lookahead allocation for prefill. - - When chunking is enabled with multi-step, we allocate lookahead slots - for the prefills for when the prefills turn into decodes in the first - step. - """ - if is_prefill: - if self.scheduler_config.is_multi_step and enable_chunking: - # num_lookahead_slots was introduced in the context of decodes, - # in Speculative Decoding. - # When the num_scheduler_steps is 8, say, then the - # num_lookahead_slots is 7. Meaning, we are doing a 1-step of - # decode anyways and we wish to do 7 more. - # - # "lookaheads" for prefills, is introduced in support for - # Chunked-Prefill in Multi-Step. - return self.scheduler_config.num_lookahead_slots + 1 - else: - return 0 - - return self.scheduler_config.num_lookahead_slots - - def _get_num_new_uncached_and_cached_tokens( - self, - seq_group: SequenceGroup, - status: SequenceStatus, - enable_chunking: bool, - budget: SchedulingBudget, - partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, - ) -> tuple[int, int]: - """ - Returns the number of new uncached and cached tokens to schedule for a - given sequence group that's in a given `status`. - - The API could chunk the number of tokens to compute based on `budget` - if `enable_chunking` is True. If a sequence group has multiple - sequences (e.g., running beam search), it means it is in decoding - phase, so chunking doesn't happen. - - Returns (0, 0) if the new token cannot be computed due to token budget. - - The cached tokens's blocks are already computed, and the attention - backend will reuse the cached blocks rather than recomputing them. So - the scheduler could schedule these cached tokens "for free". - - Args: - seq_group: The sequence group to get the number of new tokens to - schedule. - status: The status of the sequences to get the number of new tokens - to schedule. - enable_chunking: Whether to chunk the number of tokens to compute. - budget: The budget to chunk the number of tokens to compute. - partial_prefill_metadata: information about the partial prefills - that are currently running - - - Returns: - A tuple of two ints. The first int is the number of new uncached - tokens to schedule. The second int is the number of cached tokens. - If no more new tokens can be scheduled, returns (0, 0). 
- """ - num_cached_new_tokens = 0 - num_uncached_new_tokens = 0 - - seqs = seq_group.get_seqs(status=status) - # Compute the number of new uncached and cached tokens for - # each sequence. - for seq in seqs: - if not seq.is_prefill(): - # Decode sequences should always just have 1 uncached token - # TODO(rickyx): Actually is this still correct for multi-step? - num_uncached_new_tokens += 1 - continue - - num_computed_tokens_seq = seq.get_num_computed_tokens() - all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq - if not self.cache_config.enable_prefix_caching: - # If prefix caching is not enabled, all new tokens are uncached. - num_uncached_new_tokens += all_num_new_tokens_seq - continue - - # NOTE: the cache token might be currently in a block that's in an - # evictor meaning that it's not yet allocated. However, we don't - # exclude such tokens in the cache count because it will be - # guaranteed to be allocated later if the sequence can be allocated. - num_cached_tokens_seq = self.block_manager.get_num_cached_tokens( - seq) - - # Sanity check. - if num_cached_tokens_seq < num_computed_tokens_seq: - # This should only happen with chunked prefill, and - # the seq is still in prefill. The `num_cached_tokens_seq` - # is the value we calculated on scheduling the first prefill. - # For subsequent continuous prefill steps, we cached the - # number of cache tokens for the sequence so the cached token - # count could be less than the number of computed tokens. - # See comments on `ComputedBlocksTracker` for more details. - assert ( - seq.is_prefill() and seq.status == SequenceStatus.RUNNING - and self.scheduler_config.chunked_prefill_enabled - ), ("Number of cached tokens should not be less than the " - "number of computed tokens for a sequence that's still " - f"in prefill. But there are {num_cached_tokens_seq} cached " - f"tokens and {num_computed_tokens_seq} computed tokens " - f"for sequence {seq.seq_id}.") - - num_cached_new_tokens_seq = max( - 0, num_cached_tokens_seq - num_computed_tokens_seq) - num_uncached_new_tokens_seq = (all_num_new_tokens_seq - - num_cached_new_tokens_seq) - - num_uncached_new_tokens += num_uncached_new_tokens_seq - num_cached_new_tokens += num_cached_new_tokens_seq - - if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0: - # For a fully cached hit sequence, we actually need to recompute the - # last token. So we need at least 1 uncached token to schedule. - # See ModelRunner._compute_for_prefix_cache_hit for more details. - num_uncached_new_tokens = 1 - num_cached_new_tokens -= 1 - - if enable_chunking and len(seqs) == 1: - # Chunk if a running request cannot fit in the given budget. - # If number of seq > 1, it means it is doing beam search - # in a decode phase. Do not chunk. - num_uncached_new_tokens = self._chunk_new_tokens_to_schedule( - self.scheduler_config, - self.cache_config, - budget, - self._get_prompt_limit(seq_group), - num_uncached_new_tokens, - self.partial_prefill_budget_lookup_list, - partial_prefill_metadata, - ) - - return num_uncached_new_tokens, num_cached_new_tokens - - @staticmethod - def _chunk_new_tokens_to_schedule( - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - budget: SchedulingBudget, - prompt_limit: int, - num_new_tokens: int, - partial_prefill_budget_lookup_list: list[int], - partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, - ) -> int: - """ - Chunks the number of new tokens to schedule based on the budget when - chunked prefill is enabled. 
- - Args: - scheduler_config: The scheduler config. - cache_config: The cache config. - budget: The budget to chunk the number of tokens to compute. - prompt_limit: The maximum number of tokens allowed in a prompt. - num_new_tokens: The number of new tokens to schedule. - - Returns: - The number of new tokens to schedule after chunking. - """ - remaining_token_budget = budget.remaining_token_budget() - if scheduler_config.is_multi_step: - # The current multi-step + chunked prefill capability does - # not actually support chunking prompts. - # - # Therefore, `num_new_tokens` is computed in the same fashion - # for both multi-step+chunked-prefill & - # multi-step+chunked-prefill+APC - # - # Prompts with more tokens than the current remaining budget - # are postponed to future scheduler steps - if num_new_tokens > prompt_limit: - # If the seq_group is in prompt-stage, pass the - # num_new_tokens as-is so the caller can ignore - # the sequence. - return num_new_tokens - - return 0 if num_new_tokens > \ - remaining_token_budget else num_new_tokens - - # Get the number of tokens to allocate to this prefill slot - prefill_slot_budget = ( - remaining_token_budget if partial_prefill_metadata is None else - partial_prefill_budget_lookup_list[ - partial_prefill_metadata.schedulable_prefills]) - - if cache_config.enable_prefix_caching: - # When prefix caching is enabled and we're partially prefilling - # a sequence, we always allocate a number of new tokens that is - # divisible by the block size to avoid partial block matching. - block_size = cache_config.block_size - # Don't exceed either the total budget or slot budget. - # Take min of those and get the next lowest multiple of the - # block size: - remaining_token_budget = ( - min(remaining_token_budget, prefill_slot_budget) // - block_size) * block_size - # NB: In the case where num_new_tokens < budget, we are - # finishing prefill for this sequence, so we do not need to - # allocate a full block. - - num_new_tokens = min(num_new_tokens, remaining_token_budget, - prefill_slot_budget) - - return num_new_tokens diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py index 16624d08f..6161cce00 100644 --- a/vllm_spyre/platform.py +++ b/vllm_spyre/platform.py @@ -27,7 +27,6 @@ else: ModelConfig = None VllmConfig = None -import vllm.envs as envs from vllm.platforms import Platform, PlatformEnum import vllm_spyre.envs as envs_spyre @@ -121,16 +120,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if model_config.task \ else "embed" in model_config.supported_tasks - if is_decoder and not envs.VLLM_USE_V1: - raise ValueError("Decoder models are only supported on v1") + if not bool(int(os.getenv("VLLM_USE_V1", "1"))): + raise ValueError("vllm-spyre is only supported with vLLM v1. 
" + "Please set VLLM_USE_V1=1") elif not is_decoder and not is_pooling: raise ValueError("Only the 'generate' and 'embed' tasks are " "supported") if parallel_config.worker_cls == "auto": - parallel_config.worker_cls = ( - f'vllm_spyre{".v1" if envs.VLLM_USE_V1 else ""}'\ - '.worker.spyre_worker.SpyreWorker') + parallel_config.worker_cls = "vllm_spyre.v1.worker."\ + "spyre_worker.SpyreWorker" cls._check_threading_config(parallel_config.world_size) @@ -163,18 +162,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: model_config.max_model_len = max_seq_len scheduler_config.max_num_seqs = max_batch_size - if is_decoder: - scheduler_config.scheduler_cls = ( + scheduler_config.scheduler_cls = ( "vllm_spyre.v1.core.scheduler."\ "StaticBatchingSpyreScheduler") - elif is_pooling: - if not envs.VLLM_USE_V1: - scheduler_config.scheduler_cls = ( - "vllm_spyre.core.scheduler.SpyreScheduler") - else: - scheduler_config.scheduler_cls = ( - "vllm_spyre.v1.core.scheduler."\ - "StaticBatchingSpyreScheduler") # To disable any paged attention ops in the base scheduler, we: # - Set the block size (in tokens) to the maximum sequence length @@ -186,17 +176,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # length requests, so that the scheduler will always have token # budget available to schedule a full batch if cache_config is not None: - if envs.VLLM_USE_V1: - # overriding number of available Spyre blocks if not None - if cache_config.num_gpu_blocks_override: - cls._num_spyre_blocks_override = \ - cache_config.num_gpu_blocks_override - # The V1 scheduler actually needs 2 blocks for each sequence... - cache_config.num_gpu_blocks_override = \ - scheduler_config.max_num_seqs * 2 - else: - cache_config.num_gpu_blocks_override = \ - scheduler_config.max_num_seqs + # overriding number of available Spyre blocks if not None + if cache_config.num_gpu_blocks_override: + cls._num_spyre_blocks_override = \ + cache_config.num_gpu_blocks_override + # The V1 scheduler actually needs 2 blocks for each sequence... + cache_config.num_gpu_blocks_override = \ + scheduler_config.max_num_seqs * 2 cache_config.block_size = model_config.max_model_len scheduler_config.max_num_batched_tokens = ( diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py index fbf6a3f17..797868a81 100644 --- a/vllm_spyre/v1/worker/spyre_worker.py +++ b/vllm_spyre/v1/worker/spyre_worker.py @@ -660,7 +660,7 @@ def execute_model( # handler from vLLM when it starts a process for the engine code. Therefore, # the engine does not have a chance to gracefully shutdown. 
def maybe_override_signals_handler(): - if not (envs.VLLM_USE_V1 and envs.VLLM_ENABLE_V1_MULTIPROCESSING + if not (envs.VLLM_ENABLE_V1_MULTIPROCESSING and envs_spyre.VLLM_SPYRE_OVERRIDE_SIGNALS_HANDLER): return diff --git a/vllm_spyre/worker/__init__.py b/vllm_spyre/worker/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/vllm_spyre/worker/spyre_embedding_model_runner.py b/vllm_spyre/worker/spyre_embedding_model_runner.py deleted file mode 100644 index 3be4e67c3..000000000 --- a/vllm_spyre/worker/spyre_embedding_model_runner.py +++ /dev/null @@ -1,204 +0,0 @@ -import time -from collections.abc import Iterable -from typing import Optional - -import torch -from transformers import AutoModel -from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, - SchedulerConfig) -from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import Pooler, PoolingType -from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.pooling_params import PoolingParams -from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, - SequenceGroupMetadata) - -import vllm_spyre.envs as envs_spyre - -from .spyre_model_runner import ModelInputForSpyre, SpyreModelRunner - -logger = init_logger(__name__) - -BACKEND_LIST = ['sendnn', 'inductor'] - - -class SpyreEmbeddingModelRunner(SpyreModelRunner): - - # Map of request_id -> generator used for seeded random sampling - generators: dict[str, torch.Generator] = {} - - def __init__( - self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - is_driver_worker: bool, - ): - super().__init__(model_config=model_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - is_driver_worker=is_driver_worker) - - pooler_config = model_config.pooler_config - if hasattr(Pooler, "from_config_with_defaults"): - # TODO: remove this when we no longer support - # vllm version v0.9.2 - self.pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.CLS, - normalize=True, - softmax=False) - else: - self.pooler = Pooler.for_embed( - pooler_config=pooler_config, - default_pooling_type=PoolingType.CLS, - default_normalize=True, - default_softmax=False) - - def load_model(self, prompt_lens: Iterable[int], - num_decode_tokens: Iterable[int]) -> None: - self.model = AutoModel.from_pretrained(self.model_config.model) - self.model.eval() - torch.set_grad_enabled(False) - if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST: - self.model = torch.compile( - self.model, - mode="default", - dynamic=False, - backend=envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND) - - @property - def vocab_size(self) -> int: - return self.model.config.vocab_size - - def prepare_input_tensors( - self, - seq_group_metadata_list: list[SequenceGroupMetadata], - finished_requests_ids: Optional[list[str]] = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, PoolingMetadata]: - # NOTE: We assume that all sequences in the group are all prompts - (input_tokens, input_positions, input_masks, - seq_lens) = self._prepare_prompt(seq_group_metadata_list) - - pooling_metadata = self._prepare_pooling( - seq_group_metadata_list=seq_group_metadata_list, - prompt_lens=seq_lens) - return (input_tokens, input_positions, input_masks, pooling_metadata) - - def _prepare_pooling( - self, - seq_group_metadata_list: list[SequenceGroupMetadata], - prompt_lens: list[int], - ) -> PoolingMetadata: - 
"""Prepare PoolingMetadata for the sequence group metadata list.""" - seq_groups: list[tuple[list[int], PoolingParams]] = [] - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = list(seq_group_metadata.seq_data.keys()) - pooling_params = seq_group_metadata.pooling_params - seq_groups.append((seq_ids, pooling_params)) - - seq_data: dict[int, SequenceData] = {} - for seq_group_metadata in seq_group_metadata_list: - seq_data.update(seq_group_metadata.seq_data) - - pooling_metadata = PoolingMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=prompt_lens, - ) - - return pooling_metadata - - def pad_input_ids( - self, - input_ids_list: list[torch.Tensor], - min_pad_length: int = 0, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - padded_input_ids_list, mask_list, position_ids_list = self.\ - _prepare_pad_input_ids(input_ids_list, min_pad_length) - - input_ids = torch.stack(padded_input_ids_list) - mask = torch.stack(mask_list) - position_ids = torch.stack(position_ids_list) - - return input_ids, position_ids, mask - - def prepare_model_input( - self, - seq_group_metadata_list: list[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[list[str]] = None - ) -> ModelInputForSpyre: - - (input_tokens, input_positions, input_masks, - pooling_metadata) = self.prepare_input_tensors( - seq_group_metadata_list, finished_requests_ids) - - return ModelInputForSpyre(input_tokens=input_tokens, - input_positions=input_positions, - input_masks=input_masks, - pooling_metadata=pooling_metadata) - - def execute_model( - self, - model_input: ModelInputForSpyre, - kv_caches: Optional[list[torch.Tensor]] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - **kwargs, - ) -> Optional[list[PoolerOutput]]: - - t0 = time.time() - - outputs = self.model( - input_ids=model_input.input_tokens, - # Let the Embedding layer use it's default - # because the rules can be a bit different - # e.g. 
For Roberta models the inputs start - # at padding_inx +1 - #position_ids=input_positions, - attention_mask=model_input.input_masks) - hidden_states = outputs["last_hidden_state"] - - unpadded = [] - max_len = hidden_states.shape[1] - - if model_input.pooling_metadata is not None: - for i, seq_len in enumerate( - model_input.pooling_metadata.prompt_lens): - unpadded.append(hidden_states[i, max_len - seq_len:, :]) - - hidden_states = torch.concat(unpadded) - - pooler_output = self.pooler( - hidden_states=hidden_states, - pooling_metadata=model_input.pooling_metadata) - - t1 = time.time() - t0 - print("[spyre_model_runner:execute_model] t_token: %.2fms" % - (t1 * 1000)) - - return [pooler_output] - - def _raw_model_forward( - self, - input_ids: torch.Tensor, - mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - past_key_value_states: Optional[list[tuple[torch.Tensor, - torch.Tensor]]] = None, - use_cache: bool = False, - only_last_token: bool = False, - attn_algorithm: Optional[str] = None - ) -> tuple[torch.Tensor, Optional[list[tuple[torch.Tensor, - torch.Tensor]]]]: - - hidden_states, _ = self.model( - input_ids=input_ids, - attention_mask=mask, - #position_ids=position_ids - ) - return hidden_states, None diff --git a/vllm_spyre/worker/spyre_model_runner.py b/vllm_spyre/worker/spyre_model_runner.py deleted file mode 100644 index 104b67310..000000000 --- a/vllm_spyre/worker/spyre_model_runner.py +++ /dev/null @@ -1,439 +0,0 @@ -import time -from collections.abc import Iterable -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, TypeVar - -import torch -from torch import nn -from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, - SchedulerConfig) -from vllm.logger import init_logger -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import is_pin_memory_available -from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, - _add_sampling_metadata_broadcastable_dict, - _init_sampling_metadata_from_tensor_dict) - -from vllm_spyre.model_executor.model_loader.spyre import SpyreCausalLM -from vllm_spyre.platform import SpyrePlatform - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - from vllm.model_executor.pooling_metadata import PoolingMetadata - -logger = init_logger(__name__) - -TModelInputForSpyre = TypeVar('TModelInputForSpyre', - bound="ModelInputForSpyre") - - -@dataclass(frozen=True) -class ModelInputForSpyre(ModelRunnerInputBase): - """ - Used by the SpyreModelRunner. 
- """ - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - input_masks: Optional[torch.Tensor] = None - sampling_metadata: Optional[SamplingMetadata] = None - pooling_metadata: Optional["PoolingMetadata"] = None - is_prompt: Optional[bool] = None - # unused - virtual_engine: Optional[int] = None - - def as_broadcastable_tensor_dict(self) -> dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "input_positions": self.input_positions, - "input_masks": self.input_masks, - "is_prompt": self.is_prompt, - } - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls: type[TModelInputForSpyre], - tensor_dict: dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> TModelInputForSpyre: - tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) - return cls(**tensor_dict) - - -class SpyreModelRunner(ModelRunnerBase[ModelInputForSpyre]): - - def __init__( - self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - is_driver_worker: bool, - ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.is_driver_worker = is_driver_worker - self.spyre_warmup_shapes = SpyrePlatform.get_warmup_shapes( - self.scheduler_config) - - self.pad_token_id = 0 - if model_config is not None: - if model_config.hf_config is not None: - self.pad_token_id = getattr(model_config.hf_config, - "pad_token_id", None) or 0 - if model_config.get_sliding_window(): - logger.warning("Sliding window is not supported on Spyre. " - "The model will run without sliding window.") - self.device_config = (device_config - if device_config is not None else DeviceConfig()) - self.device = self.device_config.device - self.pin_memory = is_pin_memory_available() - # position_ids of all the sequences in current batch - self._position_ids: torch.Tensor = None - # attention masks of all the sequences in current batch - self._mask: torch.Tensor = None - # mapping: request id to index in batch - self._req_ids2idx: dict = {} - # Lazy initialization: after load_model. 
- self.model: nn.Module - - def get_model(self) -> nn.Module: - return self.model - - def load_model(self, prompt_lens: Iterable[int], - num_decode_tokens: Iterable[int]) -> None: - max_pad_length = max(prompt_lens) - max_decode_length = max(num_decode_tokens) - self.model = SpyreCausalLM( - self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - max_prompt_length=max_pad_length, - max_decode_length=max_decode_length, - ) - - @property - def vocab_size(self) -> int: - return self.model.model.model.config.src_vocab_size - - def _prepare_prompt( - self, - seq_group_metadata_list: list[SequenceGroupMetadata], - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[int]]: - assert len(seq_group_metadata_list) > 0 - input_token_list: list[torch.Tensor] = [] - - # find warmup shape to be used for padding and batching - applicable_spyre_warmup_shapes = [ - shape for shape in self.spyre_warmup_shapes - if len(seq_group_metadata_list) <= shape['batch_size'] - ] - for seq_group_metadata in seq_group_metadata_list: - seq_data = seq_group_metadata.seq_data[list( - seq_group_metadata.seq_data.keys())[0]] - # retrieve initial (unpadded) tokens - prompt_tokens = seq_data.get_token_ids() - new_tokens = seq_group_metadata.sampling_params.max_tokens\ - if seq_group_metadata.sampling_params is not None else 0 - - updated_spyre_warmup_shapes = [ - shape for shape in applicable_spyre_warmup_shapes - if len(prompt_tokens) <= shape['prompt_length'] - and new_tokens <= shape['new_tokens'] - ] - applicable_spyre_warmup_shapes = updated_spyre_warmup_shapes - - assert applicable_spyre_warmup_shapes - - # If multiple warmup shapes apply, the first one is selected. - # For improving performance, the warmup shapes in scheduler_config - # are ordered by "processing speed". 
- min_pad_length_batch = applicable_spyre_warmup_shapes[0][ - 'prompt_length'] - padded_batch_size = applicable_spyre_warmup_shapes[0]['batch_size'] - - self._req_ids2idx = {} - for idx, seq_group_metadata in enumerate(seq_group_metadata_list): - assert seq_group_metadata.is_prompt - self._req_ids2idx[seq_group_metadata.request_id] = idx - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - seq_data = seq_group_metadata.seq_data[seq_id] - # retrieve initial (unpadded) tokens - prompt_tokens = seq_data.get_token_ids() - - input_token_list.append( - torch.tensor(prompt_tokens, - dtype=torch.long, - device=torch.device("cpu"))) - - actual_batch_size = len(input_token_list) - self.model.indices = torch.cat([ - torch.ones(actual_batch_size, dtype=torch.bool, device='cpu'), - torch.zeros(padded_batch_size - actual_batch_size, - dtype=torch.bool, - device='cpu') - ]) - - # padding to compiled batch size - while len(input_token_list) < padded_batch_size: - input_token_list.append( - torch.zeros(min_pad_length_batch, - dtype=torch.long, - device=torch.device("cpu"))) - - # get position ids and attention mask - input_tokens, self._position_ids, self._mask = self.pad_input_ids( - input_token_list, min_pad_length=min_pad_length_batch) - - seq_lens = [t.shape[0] for t in input_token_list] - - return input_tokens, self._position_ids, self._mask, seq_lens - - def _prepare_decode( - self, - seq_group_metadata_list: list[SequenceGroupMetadata], - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert len(seq_group_metadata_list) > 0 - input_tokens: list[list[int]] = [ - [0] for _ in range(self._position_ids.shape[0]) - ] - - for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens[self._req_ids2idx[seq_group_metadata.request_id]] = [ - generation_token - ] - - # update position ids and attention mask - self._update_position_ids() - self._update_mask() - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - - return input_tokens, self._position_ids, self._mask - - def _update_position_ids(self) -> None: - """Updating the position ids of all sequences - in a batch. Will be called in decoding phase""" - - self._position_ids = self._position_ids[:, -1] + 1 - self._position_ids = self._position_ids.unsqueeze(-1) - - def _update_mask(self) -> None: - """Updating/extending the attention masks of all - sequences in a batch. 
Will be called in decoding phase""" - - assert self._mask is not None - masks = self._mask - - masks_new = [] - for mask in masks: - # get the last row of the 3d mask - mask_new = mask[-1:, :] - - # extend the mask one slot - mask_new = torch.cat( - ( - mask_new, - torch.zeros( - 1, 1, dtype=mask_new.dtype, device=mask_new.device), - ), - dim=1, - ) - masks_new.append(mask_new) - - self._mask = torch.stack(masks_new, dim=0) - - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: dict[str, Any]) -> ModelInputForSpyre: - return ModelInputForSpyre.from_broadcasted_tensor_dict(tensor_dict) - - def prepare_model_input( - self, - seq_group_metadata_list: list[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[list[str]] = None - ) -> ModelInputForSpyre: - - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, input_masks, - _) = self._prepare_prompt(seq_group_metadata_list) - seq_lens = [ - input_tokens.shape[1] for i in range(input_tokens.shape[0]) - ] - else: - # updating indices: set indices of newly finished sequences False - if finished_requests_ids: - for seq_id in finished_requests_ids: - # ignore requests that are not in the batch, eg. requests - # cancelled while waiting - if (idx := self._req_ids2idx.get(seq_id)) is not None: - self.model.indices[idx] = False - (input_tokens, input_positions, - input_masks) = self._prepare_decode(seq_group_metadata_list) - seq_lens = [] - - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - # query_lens is not needed if chunked prefill is not - # supported. Since Spyre worker doesn't support chunked prefill - # just use seq_lens instead. - seq_lens, - self.device, - self.pin_memory, - self.get_generators(finished_requests_ids)) - - return ModelInputForSpyre(input_tokens=input_tokens, - input_positions=input_positions, - input_masks=input_masks, - sampling_metadata=sampling_metadata, - is_prompt=is_prompt) - - def execute_model( - self, - model_input: ModelInputForSpyre, - kv_caches: Optional[list[torch.Tensor]] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - **kwargs, - ) -> Optional[list[SamplerOutput]]: - - t0 = time.time() - - if num_steps > 1: - raise ValueError( - "SpyreModelRunner does not support multi-step execution.") - - hidden_states = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - masks=model_input.input_masks, - is_prompt=model_input.is_prompt, - ) - - # Only perform sampling in the driver worker. - if not self.is_driver_worker: - return [] - - # Compute the logits. - logits = self.model.compute_logits(hidden_states, - model_input.sampling_metadata) - - # Sample the next token. 
-        output = self.model.sample(
-            logits=logits,
-            sampling_metadata=model_input.sampling_metadata,
-        )
-        t1 = time.time() - t0
-        print("[spyre_model_runner:execute_model] t_token: %.2fms" %
-              (t1 * 1000))
-
-        return [output]
-
-    def _prepare_pad_input_ids(
-        self,
-        input_ids_list: list[torch.Tensor],
-        min_pad_length: int = 0,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """left side padding implemented as
-        in fms.utils.generation.pad_input_id"""
-        max_len = max([min_pad_length] +
-                      [seq.size(0) for seq in input_ids_list])
-        padded_input_ids_list = []
-        mask_list = []
-        position_ids_list = []
-        for input_ids_i in input_ids_list:
-            seq_len = input_ids_i.size(0)
-            if max_len > seq_len:
-                print(f"[SpyreModelRunner] INFO: Padding request of length "
-                      f"{seq_len} tokens to {max_len} tokens.")
-            pads = torch.ones(max_len - seq_len,
-                              dtype=torch.long,
-                              device=input_ids_i.device) * self.pad_token_id
-            non_pads = torch.ones(seq_len,
-                                  dtype=torch.long,
-                                  device=input_ids_i.device)
-
-            pos_ids_pads = pads
-            pos_ids_seq = torch.arange(0,
-                                       seq_len,
-                                       dtype=torch.long,
-                                       device=input_ids_i.device)
-
-            # Setting this to 0, however if 0 is the eos, we will end up
-            # truncating the output if using truncate_after_eos once this
-            # workflow works for nested tensor, this can probably be removed
-            padded_input_ids_list.append(torch.cat((pads, input_ids_i)))
-            mask_list.append(torch.cat((torch.zeros_like(pads), non_pads)))
-            position_ids_list.append(torch.cat((pos_ids_pads, pos_ids_seq)))
-
-        return padded_input_ids_list, mask_list, position_ids_list
-
-    def pad_input_ids(
-        self,
-        input_ids_list: list[torch.Tensor],
-        min_pad_length: int = 0,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-
-        padded_input_ids_list, mask_list, position_ids_list = self.\
-            _prepare_pad_input_ids(input_ids_list, min_pad_length)
-
-        input_ids = torch.stack(padded_input_ids_list)
-        mask = torch.stack(mask_list).bool()
-        # this is a causal mask for generation
-        mask = (mask.unsqueeze(-1) == mask.unsqueeze(-2)).tril()
-        mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
-        mask = mask.to(self.model.model.dtype)
-        position_ids = torch.stack(position_ids_list)
-
-        return input_ids, position_ids, mask
-
-    def _raw_model_forward(
-        self,
-        input_ids: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        past_key_value_states: Optional[list[tuple[torch.Tensor,
-                                                   torch.Tensor]]] = None,
-        use_cache: bool = False,
-        only_last_token: bool = False,
-        attn_algorithm: Optional[str] = None
-    ) -> tuple[torch.Tensor, Optional[list[tuple[torch.Tensor,
-                                                 torch.Tensor]]]]:
-
-        # For AIU, we should not assign attn_algorithm
-        extra_kargs = {}
-        if attn_algorithm is not None:
-            extra_kargs['attn_algorithm'] = attn_algorithm
-
-        return self.model.model.model(
-            input_ids,
-            mask=mask,
-            position_ids=position_ids,
-            past_key_value_states=past_key_value_states,
-            use_cache=use_cache,
-            only_last_token=only_last_token,
-            **extra_kargs)
diff --git a/vllm_spyre/worker/spyre_worker.py b/vllm_spyre/worker/spyre_worker.py
deleted file mode 100644
index 6c0be016a..000000000
--- a/vllm_spyre/worker/spyre_worker.py
+++ /dev/null
@@ -1,401 +0,0 @@
-"""A Spyre worker class."""
-import json
-import os
-import platform
-import time
-from typing import Optional
-
-import torch
-import torch.distributed as dist
-import vllm.envs as envs
-from huggingface_hub import hf_hub_download
-from vllm.config import VllmConfig
-from vllm.distributed import (ensure_model_parallel_initialized,
-
init_distributed_environment) -from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest -from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, - WorkerInput) - -import vllm_spyre.envs as envs_spyre -import vllm_spyre.perf_metrics as perf_metrics -from vllm_spyre.model_executor.model_loader import spyre_setup -from vllm_spyre.platform import SpyrePlatform -from vllm_spyre.v1.worker.spyre_worker import _maybe_warmup_context -from vllm_spyre.worker.spyre_embedding_model_runner import ( - SpyreEmbeddingModelRunner) -# from vllm.worker.spyre_model_runner import SpyreModelRunner -from vllm_spyre.worker.spyre_model_runner import SpyreModelRunner - -# Post 0.7.3 this class was renamed -# yapf conflicts with isort for this block -# yapf: disable -try: - from vllm.worker.worker_base import LoRANotSupportedWorkerBase -except ImportError: - from vllm.worker.worker_base import ( - LoraNotSupportedWorkerBase as LoRANotSupportedWorkerBase) -# yapf: enable - - -class SpyreWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): - """A worker class that executes the model on a group of Spyre cores. - """ - - @property - def is_pooling(self) -> bool: - return self.model_config.task == "embed" \ - if self.model_config.task else \ - "embed" in self.model_config.supported_tasks - - @property - def is_decoder(self) -> bool: - return self.model_config.task == "generate" \ - if self.model_config.task else \ - "generate" in self.model_config.supported_tasks - - def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False, - ) -> None: - WorkerBase.__init__(self, vllm_config=vllm_config) - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - self.is_driver_worker = is_driver_worker - self.perf_metrics = perf_metrics.create_perf_metric_logger(rank) - if self.parallel_config and is_driver_worker: - assert rank % self.parallel_config.tensor_parallel_size == 0, \ - "Driver worker should be rank 0 of tensor parallel group." - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - - if self.is_pooling: - self.model_runner: SpyreModelRunner = SpyreEmbeddingModelRunner( - self.model_config, self.parallel_config, self.scheduler_config, - self.device_config, self.is_driver_worker) - else: - self.model_runner = SpyreModelRunner(self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, - self.is_driver_worker) - self._env_initialized = False - self.spyre_warmup_shapes = SpyrePlatform.get_warmup_shapes( - self.scheduler_config) - # Torch profiler. 
Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace - if envs.VLLM_TORCH_PROFILER_DIR: - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - activities = [torch.profiler.ProfilerActivity.CPU] - - if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn": - from torch_sendnn import torch_sendnn - torch.utils.rename_privateuse1_backend("aiu") - torch._register_device_module("aiu", - torch_sendnn.sendnn_backend) - torch.utils.generate_methods_for_privateuse1_backend() - activities.append(torch.profiler.ProfilerActivity.PrivateUse1) - - self.profiler = torch.profiler.profile( - activities=activities, - record_shapes=True, - with_stack=True, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) - print( - "[SpyreWorker] Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir) - else: - self.profiler = None - - def start_profile(self): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - self.profiler.start() - - def stop_profile(self): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - self.profiler.stop() - - def init_distributed_environment(self) -> None: - """Initialize the distributed environment.""" - - torch._C._distributed_c10d._register_process_group( - "default", dist.group.WORLD) - - if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn": - spyre_setup.spyre_dist_setup( - rank=self.rank, - world_size=self.parallel_config.world_size, - verbose=True) - - # A small all_reduce for warmup. - torch.distributed.all_reduce(torch.zeros(1).cpu()) - - def init_device(self) -> None: - - if platform.machine() == "s390x": - from torch.serialization import LoadEndianness - torch.serialization.set_default_load_endianness( - LoadEndianness.LITTLE) - - if not self._env_initialized: - - init_distributed_environment( - world_size=self.parallel_config.world_size, - rank=self.rank, - distributed_init_method="env://", - backend="gloo", - ) - - if self.parallel_config.world_size > 1: - self.init_distributed_environment() - elif envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn": - spyre_setup.spyre_setup() - - ensure_model_parallel_initialized( - self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size, - ) - - self._env_initialized = True - - # Set random seed. 
- set_random_seed(self.model_config.seed) - - def load_model(self): - assert self._env_initialized - - is_local = os.path.isdir(self.model_config.model) - if is_local: - cf_file = os.path.join(self.model_config.model, 'config.json') - else: - cf_file = hf_hub_download(repo_id=self.model_config.model, - revision=self.model_config.revision, - filename="config.json") - with open(cf_file, 'rb') as f: - config = json.load(f) - - restricted_tokens = [] - if tok := config.get("bos_token_id") is not None: - restricted_tokens.append(int(tok)) - if tok := config.get("eos_token_id") is not None: - restricted_tokens.append(int(tok)) - - print("[SpyreWorker] load model...") - # TODO: check additionally if the Spyre card has enough memory - # for all requested model warmups - # printing env variables for debugging purposes - load_model_start_t = time.time() - wup_prompt_lens, wup_new_tokens = zip( - *[(s["prompt_length"], s["new_tokens"]) - for s in self.spyre_warmup_shapes]) - - self.model_runner.load_model(prompt_lens=wup_prompt_lens, - num_decode_tokens=wup_new_tokens) - - load_model_end_t = time.time() - load_model_total_t = load_model_end_t - load_model_start_t - self.perf_metrics.log("load model time", - load_model_total_t, - model=self.model_config.model) - print(f"\tload model took {load_model_total_t}s") - - print(f"[SpyreWorker] Start warming up " - f"{len(wup_new_tokens)} " - f"different prompt/decode/batchsize-shape combinations.") - all_warmup_start_t = time.time() - for i, (prompt_len, num_decode_tokens, batch_size) in enumerate([ - (s["prompt_length"], s["new_tokens"], s["batch_size"]) - for s in self.spyre_warmup_shapes - ]): - if not self.is_pooling: - # TODO: remove if spyre supports - # lower number of output tokens - assert num_decode_tokens >= 2, ( - "VLLM_SPYRE_WARMUP_NEW_TOKENS must be " - "at least 2 (spyre requirement).") - # warmup individual combination - print(f"[SpyreWorker] Warmup {i+1}/" - f"{len(wup_new_tokens)} " - f"prompt/decode/batchsize-shape combinations...") - print(f"[SpyreWorker] Warming up for prompt length {prompt_len}, " - f"decoding {num_decode_tokens} tokens with batch " - f"size {batch_size}") - with _maybe_warmup_context(): - self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens, - restricted_tokens, batch_size) - all_warmup_end_t = time.time() - all_warmup_total_t = all_warmup_end_t - all_warmup_start_t - self.perf_metrics.log("total warmup time", all_warmup_total_t) - # No more perf metric are captured (so far) after warmup, cleanup now. - del self.perf_metrics - print(f"[SpyreWorker] All warmups for " - f"{len(wup_new_tokens)} different " - f"prompt/decode/batchsize-shape combinations finished. 
" - f"Total warmup time {all_warmup_total_t}s.") - - def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, - special_token_ids, batch_size): - # warmup the model - warmup_start_t = time.time() - # NOTE(ngl): empty tensor causes spyre to hang, so using - # randint without 0 and the eos and bos token - - # Create a list of valid values between 1 (inclusive) and vocab - # size (exclusive) by excluding the eos and bos token ids - # (in special_token_ids) - vocab_size = self.model_runner.vocab_size - valid_token_ids = [ - i for i in range(1, vocab_size) if i not in set(special_token_ids) - ] - # Convert to tensor for sampling - valid_token_ids_tensor = torch.tensor(valid_token_ids, - dtype=torch.long, - device=torch.device("cpu")) - # Sample from the valid token ids - warmup_tokens_tensor = valid_token_ids_tensor[torch.randint( - 0, len(valid_token_ids_tensor), (batch_size, prompt_len))] - - extra_kwargs = {} - if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn": - # Bug in 2.3.1 fixed in 2.4.1 for SDPA flash cpu - # impl when padding too much - extra_kwargs["attn_algorithm"] = "math" - - print(f"[SpyreWorker] warmup for prompt length " - f"{prompt_len} and max output tokens {num_decode_tokens}.") - - print("[SpyreWorker] warmup 1/2...") - # warmup 1st forward pass - self._warmup_model_forward_pass(warmup_tokens_tensor, - valid_token_ids_tensor, prompt_len, - num_decode_tokens, batch_size, - extra_kwargs) - self.perf_metrics.log("warmup 1 time", - time.time() - warmup_start_t, - batch_size=batch_size, - max_tokens=num_decode_tokens, - prompt_len=prompt_len) - - print("[SpyreWorker] warmup 2/2...") - # warmup 2nd forward pass - warmup2_start_t = time.time() - self._warmup_model_forward_pass(warmup_tokens_tensor, - valid_token_ids_tensor, prompt_len, - num_decode_tokens, batch_size, - extra_kwargs) - - warmup_end_t = time.time() - warmup_total_t = warmup_end_t - warmup_start_t - self.perf_metrics.log("warmup 2 time", - time.time() - warmup2_start_t, - batch_size=batch_size, - max_tokens=num_decode_tokens, - prompt_len=prompt_len) - print("[SpyreWorker] ... 
warmup finished.") - print(f"\twarmup took {warmup_total_t}s (for prompt length" - f"{prompt_len} and max output tokens {num_decode_tokens})") - - def _warmup_model_forward_pass(self, warmup_tokens_tensor, - valid_token_ids_tensor, prompt_len, - num_decode_tokens, batch_size, - extra_kwargs): - # padding warmup tokens to obtain the - # corresponding position ids and mask - warmup_tokens_pad, self.model_runner._position_ids, \ - self.model_runner._mask = self.model_runner.pad_input_ids( - warmup_tokens_tensor, min_pad_length=prompt_len) - - logits, past_key_value_states = self.model_runner._raw_model_forward( - warmup_tokens_pad, - position_ids=self.model_runner._position_ids, - mask=self.model_runner._mask, - past_key_value_states=None, - use_cache=True, - only_last_token=True, - **extra_kwargs) - # decoding - for i in range(num_decode_tokens - 1): - # sampling next input token from vocab without bos and eos tokens - decode_tokens = valid_token_ids_tensor[torch.randint( - 0, len(valid_token_ids_tensor), (batch_size, 1))] - - # update mask and position_ids - self.model_runner._update_mask() - self.model_runner._update_position_ids() - - if past_key_value_states is not None: - for layer in past_key_value_states: - for tensor in layer: - torch._dynamo.mark_dynamic(tensor, 2) - - logits, past_key_value_states = self.model_runner.\ - _raw_model_forward( - decode_tokens, - position_ids=self.model_runner._position_ids, - mask=self.model_runner._mask, - past_key_value_states=past_key_value_states, - use_cache=True, - only_last_token=True, - **extra_kwargs) - - def determine_num_available_blocks(self) -> tuple[int, int]: - """Determine the number of available KV blocks. - - Swapping is not yet supported, so always return num_cpu_blocks=0. - - We configure num_gpu_blocks to be equal to max_num_seqs. - """ - # Set the number of GPU blocks to be the same as the maximum number of - # sequences that can be processed in a single batch. This is equivalent - # to schedule without PagedAttention. - num_gpu_blocks = self.scheduler_config.max_num_seqs - - # Swap not yet supported with Spyre backend. - num_cpu_blocks = 0 - - return num_gpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache. - """ - - # Different values are not tested. - assert num_cpu_blocks == 0 - assert num_gpu_blocks == self.scheduler_config.max_num_seqs - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - def get_cache_block_size_bytes(self) -> int: - """Determine the size in bytes of a cache block. - - This is required for speculative decoding; it is not yet implemented. - """ - raise NotImplementedError - - @property - def do_metadata_broadcast(self) -> bool: - return True - - @property - def kv_cache(self) -> Optional[list[list[torch.Tensor]]]: - return None - - def prepare_worker_input( - self, execute_model_req: ExecuteModelRequest) -> WorkerInput: - return WorkerInput(num_seq_groups=len( - execute_model_req.seq_group_metadata_list), ) - - def execute_worker(self, worker_input: WorkerInput) -> None: - pass
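
With the deletions above, the V1 worker stack is the only remaining execution path, including for embedding (pooling) models. For reference, a minimal offline-embedding sketch on that V1-only path follows. It is not part of this patch: the model name and warmup values are placeholders, and the VLLM_SPYRE_WARMUP_* variable names as well as the LLM.embed() call are assumed to behave as in current vllm-spyre/vLLM releases.

# Minimal sketch (not part of this patch): offline embedding inference on the
# V1-only path. Model name and warmup values are illustrative; the
# VLLM_SPYRE_WARMUP_* names are assumed to match current vllm-spyre.
import os

from vllm import LLM

# No VLLM_USE_V1 toggle is needed any more; V1 is the only supported engine.
os.environ.setdefault("VLLM_SPYRE_DYNAMO_BACKEND", "eager")   # 'sendnn' on Spyre hardware
os.environ.setdefault("VLLM_SPYRE_WARMUP_PROMPT_LENS", "64")  # assumed env var name
os.environ.setdefault("VLLM_SPYRE_WARMUP_BATCH_SIZES", "4")   # assumed env var name

llm = LLM(model="sentence-transformers/all-roberta-large-v1",  # placeholder model
          task="embed",
          max_model_len=256)

outputs = llm.embed(["Hello Spyre"])      # assumes vLLM's LLM.embed() pooling API
print(len(outputs[0].outputs.embedding))  # embedding dimension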