diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py new file mode 100644 index 0000000000000..26a2084b131fa --- /dev/null +++ b/vllm/v1/worker/block_table.py @@ -0,0 +1,78 @@ +from typing import List + +import numpy as np +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class BlockTable: + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + pin_memory: bool, + device: torch.device, + ): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req + self.pin_memory = pin_memory + self.device = device + + self.block_table = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), + device=self.device, + dtype=torch.int32, + ) + self.block_table_cpu = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), + device="cpu", + dtype=torch.int32, + pin_memory=pin_memory, + ) + self.block_table_np = self.block_table_cpu.numpy() + self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) + + def append_row( + self, + row_idx: int, + start: int, + block_ids: List[int], + ) -> None: + num_blocks = len(block_ids) + self.block_table_np[row_idx, start:start + num_blocks] = block_ids + self.num_blocks_per_row[row_idx] = start + num_blocks + + def add_row(self, row_idx: int, block_ids: List[int]) -> None: + self.append_row(row_idx, 0, block_ids) + + def move_row(self, src: int, tgt: int) -> None: + num_blocks = self.num_blocks_per_row[src] + self.block_table_np[tgt, :num_blocks] = self.block_table_np[ + src, :num_blocks] + self.num_blocks_per_row[tgt] = num_blocks + + def commit(self, num_reqs: int) -> None: + self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs], + non_blocking=True) + + def clear(self) -> None: + self.block_table.fill_(0) + self.block_table_cpu.fill_(0) + + def get_device_tensor(self) -> torch.Tensor: + """Ruturns the device tensor of the block table.""" + return self.block_table + + def get_cpu_tensor(self) -> torch.Tensor: + """Returns the CPU tensor of the block table.""" + return self.block_table_cpu + + def get_numpy_array(self) -> np.ndarray: + """Returns the numpy array of the block table.""" + return self.block_table_np diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index f8a1427c6c26c..40494e64b22f0 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -9,6 +9,7 @@ from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: from vllm.multimodal.inputs import PlaceholderRange @@ -70,19 +71,14 @@ def __init__( self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) - # Attention-related. - self.block_table = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), - device=self.device, - dtype=torch.int32, - ) - self.block_table_cpu_tensor = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), - device="cpu", - dtype=torch.int32, + # Block table. + self.block_table = BlockTable( + max_num_reqs=max_num_reqs, + max_model_len=max_model_len, + max_num_blocks_per_req=max_num_blocks_per_req, pin_memory=pin_memory, + device=device, ) - self.block_table_cpu = self.block_table_cpu_tensor.numpy() # Sampling-related. self.temperature = torch.empty((max_num_reqs, ), @@ -193,8 +189,7 @@ def add_request( self.num_tokens[req_index] = request.num_tokens self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - num_blocks = len(request.block_ids) - self.block_table_cpu[req_index, :num_blocks] = request.block_ids + self.block_table.add_row(req_index, request.block_ids) sampling_params = request.sampling_params self.temperature_cpu[req_index] = sampling_params.temperature @@ -300,9 +295,7 @@ def condense(self, empty_req_indices: List[int]) -> None: self.num_prompt_tokens[last_req_index] self.num_computed_tokens_cpu[ empty_index] = self.num_computed_tokens_cpu[last_req_index] - # TODO(woosuk): Optimize the copy of block_table_cpu. - self.block_table_cpu[empty_index] = self.block_table_cpu[ - last_req_index] + self.block_table.move_row(last_req_index, empty_index) self.temperature_cpu[empty_index] = self.temperature_cpu[ last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 294c76cfb680e..31e693235f99f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -211,10 +211,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if num_new_blocks == 0: continue start_index = len(req_state.block_ids) - end_index = start_index + num_new_blocks req_state.block_ids.extend(req_data.new_block_ids) - self.input_batch.block_table_cpu[ - req_index, start_index:end_index] = req_data.new_block_ids + self.input_batch.block_table.append_row(req_index, start_index, + req_data.new_block_ids) req_ids_to_add: List[str] = [] # Add new requests to the cached states. @@ -275,9 +274,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. - self.input_batch.block_table[:num_reqs].copy_( - self.input_batch.block_table_cpu_tensor[:num_reqs], - non_blocking=True) + self.input_batch.block_table.commit(num_reqs) # Get the number of scheduled tokens for each request. # TODO: The Python loop can be slow. Optimize. @@ -333,8 +330,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - block_numbers = (self.input_batch.block_table_cpu_tensor.flatten() - [block_table_indices].numpy()) + block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() block_offsets = positions_np % self.block_size np.add(block_numbers * self.block_size, block_offsets, @@ -450,7 +447,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): query_start_loc=query_start_loc, max_seq_len=max_seq_len, seq_start_loc=seq_start_loc, - block_table=self.input_batch.block_table[:num_reqs], + block_table=( + self.input_batch.block_table.get_device_tensor()[:num_reqs]), slot_mapping=slot_mapping, use_cascade=use_cascade, common_prefix_len=common_prefix_len,