From 855d51751eaa87a79b49764e569425901ac96d31 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Thu, 17 Jul 2025 15:42:11 +0800 Subject: [PATCH 1/3] feat: move init_req_to_token_indexes and copy_kv_index_to_req to alloc fun --- lightllm/common/basemodel/basemodel.py | 66 +------------------ lightllm/common/basemodel/cuda_graph.py | 4 +- lightllm/common/infer_utils.py | 4 +- lightllm/common/mem_manager.py | 28 +++++++- lightllm/common/req_manager.py | 1 + .../server/router/model_infer/infer_batch.py | 6 +- .../generic_padded_pre_process.py | 8 ++- .../mode_backend/generic_pre_process.py | 6 +- .../benchmark/static_inference/model_infer.py | 6 +- .../static_inference/model_infer_mtp.py | 8 ++- 10 files changed, 58 insertions(+), 79 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 4516e18c3..99dbf6ccf 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -11,9 +11,7 @@ from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.common.mem_manager import MemoryManager from lightllm.common.req_manager import ReqManager -from lightllm.common.infer_utils import init_req_to_token_indexes from lightllm.common.build_utils import repair_config -from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.common.basemodel.cuda_graph import CudaGraph from lightllm.common.quantization import Quantcfg @@ -332,14 +330,6 @@ def _prefill( model_input: ModelInput, ): infer_state = self._create_inferstate(model_input) - init_req_to_token_indexes( - self.req_manager.req_to_token_indexs, - model_input.b_req_idx, - model_input.b_seq_len, - infer_state.b_ready_cache_len, - model_input.max_len_in_batch, - infer_state.mem_index, - ) infer_state.init_some_extra_state(self, model_input.input_ids) return self._context_forward(model_input.input_ids, infer_state) @@ -360,12 +350,6 @@ def _decode( find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size) padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size) infer_state = self._create_inferstate(padded_model_input) - copy_kv_index_to_req( - self.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.mem_index, - ) infer_state.init_some_extra_state(self, padded_model_input.input_ids) if self.graph.need_capture(find_graph_batch_size): @@ -381,12 +365,6 @@ def _decode( ) else: infer_state = self._create_inferstate(model_input) - copy_kv_index_to_req( - self.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.mem_index, - ) infer_state.init_some_extra_state(self, model_input.input_ids) model_output = self._token_forward(model_input.input_ids, infer_state) @@ -471,25 +449,9 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids infer_state0 = self._create_inferstate(model_input0, 0) - init_req_to_token_indexes( - self.req_manager.req_to_token_indexs, - model_input0.b_req_idx, - model_input0.b_seq_len, - infer_state0.b_ready_cache_len, - model_input0.max_len_in_batch, - infer_state0.mem_index, - ) infer_state0.init_some_extra_state(self, input_ids0) infer_state1 = self._create_inferstate(model_input1, 1) - init_req_to_token_indexes( - 
self.req_manager.req_to_token_indexs, - model_input1.b_req_idx, - model_input1.b_seq_len, - infer_state1.b_ready_cache_len, - model_input1.max_len_in_batch, - infer_state1.mem_index, - ) infer_state1.init_some_extra_state(self, input_ids1) model_output0, model_output1 = self._overlap_tpsp_context_forward( @@ -531,20 +493,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size) padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size) infer_state0 = self._create_inferstate(padded_model_input0, 0) - copy_kv_index_to_req( - self.req_manager.req_to_token_indexs, - infer_state0.b_req_idx, - infer_state0.b_seq_len, - infer_state0.mem_index, - ) infer_state0.init_some_extra_state(self, padded_model_input0.input_ids) infer_state1 = self._create_inferstate(padded_model_input1, 1) - copy_kv_index_to_req( - self.req_manager.req_to_token_indexs, - infer_state1.b_req_idx, - infer_state1.b_seq_len, - infer_state1.mem_index, - ) infer_state1.init_some_extra_state(self, padded_model_input1.input_ids) if self.graph.need_capture(find_graph_batch_size): @@ -569,20 +519,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size) else: infer_state0 = self._create_inferstate(model_input0, 0) - copy_kv_index_to_req( - self.req_manager.req_to_token_indexs, - infer_state0.b_req_idx, - infer_state0.b_seq_len, - infer_state0.mem_index, - ) infer_state0.init_some_extra_state(self, model_input0.input_ids) infer_state1 = self._create_inferstate(model_input1, 1) - copy_kv_index_to_req( - self.req_manager.req_to_token_indexs, - infer_state1.b_req_idx, - infer_state1.b_seq_len, - infer_state1.mem_index, - ) infer_state1.init_some_extra_state(self, model_input1.input_ids) model_output0, model_output1 = self._overlap_tpsp_token_forward( @@ -683,10 +621,12 @@ def _check_max_len_infer(self): logger.info("begin check max_len infer") dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda") b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda") - mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda() b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda") b_seq_len[:] = self.batch_max_tokens b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") + mem_indexes = self.mem_manager.alloc( + len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len, True + ).cuda() total_token_num = self.batch_max_tokens b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda") model_input = ModelInput( diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 07792865e..db77298c7 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -196,13 +196,13 @@ def warmup(self, model): total_token_num = batch_size * seq_len max_len_in_batch = self.graph_max_len_in_batch input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() b_req_idx = torch.tensor( [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") b_seq_len.fill_(seq_len) b_mtp_index = torch.zeros(batch_size, 
dtype=torch.int32, device="cuda") + mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len).cuda() model_input = ModelInput( batch_size=batch_size, @@ -252,13 +252,13 @@ def warmup_overlap(self, model): total_token_num = batch_size * seq_len max_len_in_batch = self.graph_max_len_in_batch input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() b_req_idx = torch.tensor( [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") b_seq_len.fill_(seq_len) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len).cuda() micro_batch = ModelInput( is_prefill=False, diff --git a/lightllm/common/infer_utils.py b/lightllm/common/infer_utils.py index da2f35e08..c3f980ecd 100644 --- a/lightllm/common/infer_utils.py +++ b/lightllm/common/infer_utils.py @@ -1,6 +1,4 @@ -def init_req_to_token_indexes( - req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, max_len_in_batch, alloc_mem_index -): +def init_req_to_token_indexes(req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, alloc_mem_index): start_index = 0 b_seq_len_numpy = b_seq_len.cpu().numpy() b_ready_cache_len_numpy = b_ready_cache_len.cpu().numpy() diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index 4142ce4aa..aea5bc405 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -12,6 +12,8 @@ from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.utils.dist_utils import get_current_device_id +from lightllm.common.infer_utils import init_req_to_token_indexes +from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req logger = init_logger(__name__) @@ -52,6 +54,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False layer_num, ) self.HOLD_TOKEN_MEMINDEX = self.size + self.req_to_token_indexs = None def get_cell_size(self): return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) @@ -243,7 +246,9 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to def _free_buffers(self): self.kv_buffer = None - def alloc(self, need_size) -> torch.Tensor: + def alloc( + self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False + ) -> torch.Tensor: if need_size > self.mark_end - self.mark_start: logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") assert False, "error alloc state" @@ -255,8 +260,29 @@ def alloc(self, need_size) -> torch.Tensor: self.can_use_mem_size -= need_size self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + if self.req_to_token_indexs is not None: + assert b_req_idx is not None and b_seq_len is not None, "b_req_idx and b_seq_len must be provided" + if is_prefill: + init_req_to_token_indexes( + self.req_to_token_indexs, + b_req_idx, + b_seq_len, + b_ready_cache_len, + ans, + ) + else: + copy_kv_index_to_req( + self.req_to_token_indexs, + b_req_idx.cuda(), + b_seq_len.cuda(), + ans.cuda(), + ) return ans + def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor): + self.req_to_token_indexs[req_idx, 
start:end] = values + def free(self, free_index: Union[torch.Tensor, List[int]]): """_summary_ diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 59f607a01..135cbbc22 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -62,6 +62,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana self.req_to_token_indexs = torch.zeros( (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda" ) + mem_manager.req_to_token_indexs = self.req_to_token_indexs self.mem_manager = mem_manager self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num) self.max_request_num = max_request_num diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 67d69aa38..1add8d98b 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -340,7 +340,9 @@ def _match_radix_cache(self): self.shared_kv_node = share_node ready_cache_len = share_node.node_prefix_total_len # 从 cpu 到 gpu 是流内阻塞操作 - g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor + g_infer_context.req_manager.mem_manager.set_prefix_cache_to_req( + self.req_idx, 0, ready_cache_len, value_tensor + ) self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 @@ -461,7 +463,7 @@ def diverse_copy(self, req_manager, is_prefill): req = g_infer_context.requests_mapping[req_id] req.finish_status.set_status(FinishStatus.NO_FINISH) input_len = req.get_chuncked_input_token_len() - req_manager.req_to_token_indexs[req.req_idx][prefix_len:input_len] = cache_token_id + req_manager.mem_manager.set_prefix_cache_to_req(req.req_idx, prefix_len, input_len, cache_token_id) assert input_len == pre_input_len diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 10090a576..5059ad27f 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -78,7 +78,9 @@ def padded_prepare_prefill_inputs( g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc( + input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len, True + ) g_infer_state_lock.release() if padded_req_num > 0: @@ -163,7 +165,9 @@ def padded_prepare_decode_inputs( g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0] - padded_req_num) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc( + b_seq_len.shape[0] - padded_req_num, b_req_idx, b_seq_len + ) g_infer_state_lock.release() if padded_req_num > 0: diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index d5bba1ae5..1c81ebfcd 100644 --- 
a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -56,7 +56,9 @@ def prepare_prefill_inputs( g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc( + input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True + ) g_infer_state_lock.release() model_input = ModelInput( @@ -112,7 +114,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0]) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0]) + mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0], b_req_idx, b_seq_len) g_infer_state_lock.release() model_input = ModelInput( diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index 73a99ff28..f59e01a6f 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -242,7 +242,9 @@ def run_forward_once( b_seq_len[i] = input_len total_token_num = batch_size * input_len - mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]).cuda() + mem_indexes = model_part.req_manager.mem_manager.alloc( + test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True + ).cuda() rank_id = model_kvargs["rank_id"] @@ -303,7 +305,7 @@ def run_forward_once( step_start = time.time() total_token_num += batch_size b_seq_len += 1 - mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]).cuda() + mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0], b_req_idx, b_seq_len).cuda() max_len_in_batch = input_len + i + 1 logits = decode_fn( model_part, diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index eb36bc873..3df56cdcc 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -126,7 +126,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ b_seq_len[i] = input_len total_token_num = input_len * batch_size - mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda() + mem_indexes = main_model.req_manager.mem_manager.alloc( + test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True + ).cuda() # Main model Prefill model_input = ModelInput( batch_size=batch_size, @@ -193,7 +195,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda") nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") - mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda() + mem_indexes = main_model.req_manager.mem_manager.alloc( + batch_size * (len(draft_models) + 1), nopad_b_seq_idx, nopad_b_seq_len + ).cuda() model_input = ModelInput( batch_size=batch_size * (len(draft_models) + 1), From a386b3d0ce86c3e10cb29357c6fa97ae95eb6176 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Thu, 31 Jul 2025 15:11:44 +0800 
Subject: [PATCH 2/3] feat: add page_size_variable mode for fa3 backend --- lightllm/common/basemodel/basemodel.py | 62 +++ lightllm/common/infer_utils.py | 4 +- lightllm/common/mem_manager.py | 20 - lightllm/common/mem_utils.py | 4 + .../common/page_size_variable_mem_manager.py | 173 +++++++ lightllm/common/req_manager.py | 10 +- .../llama/flashattention_infer_struct.py | 42 +- .../layer_infer/transformer_layer_infer.py | 76 ++- .../dynamic_prompt/paged_radix_cache.py | 462 ++++++++++++++++++ .../router/dynamic_prompt/radix_cache.py | 4 +- .../model_infer/mode_backend/base_backend.py | 4 +- .../generic_padded_pre_process.py | 6 +- .../mode_backend/generic_pre_process.py | 6 +- lightllm/utils/envs_utils.py | 9 + 14 files changed, 835 insertions(+), 47 deletions(-) create mode 100755 lightllm/common/page_size_variable_mem_manager.py create mode 100644 lightllm/server/router/dynamic_prompt/paged_radix_cache.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 99dbf6ccf..dbd4f780c 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -11,7 +11,9 @@ from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.common.mem_manager import MemoryManager from lightllm.common.req_manager import ReqManager +from lightllm.common.infer_utils import init_req_to_token_indexes from lightllm.common.build_utils import repair_config +from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.common.basemodel.cuda_graph import CudaGraph from lightllm.common.quantization import Quantcfg @@ -330,6 +332,14 @@ def _prefill( model_input: ModelInput, ): infer_state = self._create_inferstate(model_input) + init_req_to_token_indexes( + self.req_manager.req_to_token_indexs, + model_input.b_req_idx, + model_input.b_seq_len, + infer_state.b_ready_cache_len, + model_input.max_len_in_batch, + infer_state.mem_index, + ) infer_state.init_some_extra_state(self, model_input.input_ids) return self._context_forward(model_input.input_ids, infer_state) @@ -350,6 +360,12 @@ def _decode( find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size) padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size) infer_state = self._create_inferstate(padded_model_input) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_seq_len, + infer_state.mem_index, + ) infer_state.init_some_extra_state(self, padded_model_input.input_ids) if self.graph.need_capture(find_graph_batch_size): @@ -365,6 +381,12 @@ def _decode( ) else: infer_state = self._create_inferstate(model_input) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_seq_len, + infer_state.mem_index, + ) infer_state.init_some_extra_state(self, model_input.input_ids) model_output = self._token_forward(model_input.input_ids, infer_state) @@ -449,9 +471,25 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids infer_state0 = self._create_inferstate(model_input0, 0) + init_req_to_token_indexes( + self.req_manager.req_to_token_indexs, + model_input0.b_req_idx, + model_input0.b_seq_len, + infer_state0.b_ready_cache_len, + model_input0.max_len_in_batch, + infer_state0.mem_index, + ) 
infer_state0.init_some_extra_state(self, input_ids0) infer_state1 = self._create_inferstate(model_input1, 1) + init_req_to_token_indexes( + self.req_manager.req_to_token_indexs, + model_input1.b_req_idx, + model_input1.b_seq_len, + infer_state1.b_ready_cache_len, + model_input1.max_len_in_batch, + infer_state1.mem_index, + ) infer_state1.init_some_extra_state(self, input_ids1) model_output0, model_output1 = self._overlap_tpsp_context_forward( @@ -493,8 +531,20 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size) padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size) infer_state0 = self._create_inferstate(padded_model_input0, 0) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state0.b_req_idx, + infer_state0.b_seq_len, + infer_state0.mem_index, + ) infer_state0.init_some_extra_state(self, padded_model_input0.input_ids) infer_state1 = self._create_inferstate(padded_model_input1, 1) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state1.b_req_idx, + infer_state1.b_seq_len, + infer_state1.mem_index, + ) infer_state1.init_some_extra_state(self, padded_model_input1.input_ids) if self.graph.need_capture(find_graph_batch_size): @@ -519,8 +569,20 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size) else: infer_state0 = self._create_inferstate(model_input0, 0) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state0.b_req_idx, + infer_state0.b_seq_len, + infer_state0.mem_index, + ) infer_state0.init_some_extra_state(self, model_input0.input_ids) infer_state1 = self._create_inferstate(model_input1, 1) + copy_kv_index_to_req( + self.req_manager.req_to_token_indexs, + infer_state1.b_req_idx, + infer_state1.b_seq_len, + infer_state1.mem_index, + ) infer_state1.init_some_extra_state(self, model_input1.input_ids) model_output0, model_output1 = self._overlap_tpsp_token_forward( diff --git a/lightllm/common/infer_utils.py b/lightllm/common/infer_utils.py index c3f980ecd..da2f35e08 100644 --- a/lightllm/common/infer_utils.py +++ b/lightllm/common/infer_utils.py @@ -1,4 +1,6 @@ -def init_req_to_token_indexes(req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, alloc_mem_index): +def init_req_to_token_indexes( + req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, max_len_in_batch, alloc_mem_index +): start_index = 0 b_seq_len_numpy = b_seq_len.cpu().numpy() b_ready_cache_len_numpy = b_ready_cache_len.cpu().numpy() diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index aea5bc405..f3cf0419d 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -12,8 +12,6 @@ from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.utils.dist_utils import get_current_device_id -from lightllm.common.infer_utils import init_req_to_token_indexes -from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req logger = init_logger(__name__) @@ -260,24 +258,6 @@ def alloc( self.can_use_mem_size -= need_size self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - if self.req_to_token_indexs is not None: - assert b_req_idx is not None and b_seq_len is 
not None, "b_req_idx and b_seq_len must be provided" - if is_prefill: - init_req_to_token_indexes( - self.req_to_token_indexs, - b_req_idx, - b_seq_len, - b_ready_cache_len, - ans, - ) - else: - copy_kv_index_to_req( - self.req_to_token_indexs, - b_req_idx.cuda(), - b_seq_len.cuda(), - ans.cuda(), - ) return ans def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor): diff --git a/lightllm/common/mem_utils.py b/lightllm/common/mem_utils.py index dfb8e849d..5f3ee6164 100644 --- a/lightllm/common/mem_utils.py +++ b/lightllm/common/mem_utils.py @@ -4,6 +4,7 @@ from lightllm.common.export_calibration_mem_manager import ExportCalibrationMemoryManager from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager +from lightllm.common.page_size_variable_mem_manager import PageSizeVariableMemoryManager from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -28,6 +29,9 @@ def select_mem_manager_class(mode): elif "export_fp8kv_calibration" in mode: memory_manager_class = ExportCalibrationMemoryManager logger.info("Using mode export fp8kv calibration") + elif "page_size_variable" in mode: + memory_manager_class = PageSizeVariableMemoryManager + logger.info("Page size will be variable") else: memory_manager_class = MemoryManager logger.info("Model kv cache using mode normal") diff --git a/lightllm/common/page_size_variable_mem_manager.py b/lightllm/common/page_size_variable_mem_manager.py new file mode 100755 index 000000000..095648c01 --- /dev/null +++ b/lightllm/common/page_size_variable_mem_manager.py @@ -0,0 +1,173 @@ +import torch +import numpy as np +from .mem_manager import MemoryManager +from typing import List, Union +from lightllm.utils.log_utils import init_logger +from lightllm.utils.envs_utils import get_page_size + + +def cdiv(a, b): + return (a + b - 1) // b + + +logger = init_logger(__name__) + + +class PageSizeVariableMemoryManager(MemoryManager): + def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): + super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction) + self.req_to_page_indexs = None + page_size = get_page_size() + self.page_idx_pool = torch.arange( + 0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self.mark_page_start = 0 + self.can_use_page_size = cdiv(self.size, page_size) + + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + self.kv_buffer = torch.empty( + (layer_num, cdiv(size, get_page_size()) * get_page_size(), 2 * head_num, head_dim), + dtype=dtype, + device="cuda", + ) + + # 要求长度必须是page_size的整数倍,page内token索引必须连续 + def check_cache_page_valid(self, values: torch.Tensor): + end = len(values) + assert end % self.page_size == 0, "Values length must be a multiple of page size" + total_pages = end // self.page_size + for page_idx in range(total_pages): + values_start = page_idx * self.page_size + values_end = min((page_idx + 1) * self.page_size, end) + page_token_idxs = values[values_start:values_end] + if len(page_token_idxs) > 1: + expected_idxs = torch.arange( + page_token_idxs[0], + page_token_idxs[0] + len(page_token_idxs), + dtype=page_token_idxs.dtype, + device=page_token_idxs.device, + ) + if not torch.equal(page_token_idxs, expected_idxs): + return False + return True + + def set_prefix_cache_to_req(self, req_idx: int, start: int, end: int, values: torch.Tensor): 
+ # assert self.check_cache_page_valid(values), "Values must be valid for page size" + page_size = get_page_size() + self.req_to_page_indexs[req_idx, start // page_size : end // page_size] = values[::page_size] // page_size + self.req_to_token_indexs[req_idx, start:end] = values + + def expand_by_page_size(self, b_token_len, page_size): + # 将seq_len按page整数倍展开,例如seq_len = [9,9,9] -> page_len = [4,4,1,4,4,1,4,4,1], page_size = 4 + b_page_len = cdiv(b_token_len, page_size) + need_pages_num = b_page_len.sum() + p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device) + cumsum_pages = torch.cumsum(b_page_len, dim=0) + last_page_positions = cumsum_pages - 1 + remainders = b_token_len - (b_page_len - 1) * page_size + p_token_len[last_page_positions] = remainders + return need_pages_num, b_page_len, p_token_len + + def get_paged_token_indexs(self, b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill): + if is_prefill: + b_req_idx = b_req_idx.cuda() + b_seq_len = b_seq_len.cuda() + b_ready_cache_len = b_ready_cache_len.cuda() + + b_token_len = b_seq_len - b_ready_cache_len + total_pages_needed, b_page_len, p_token_len = self.expand_by_page_size(b_token_len, page_size) + if self.can_use_page_size < total_pages_needed: + raise RuntimeError( + f"No available pages for alloc. remaining: {self.can_use_page_size}, needed: {total_pages_needed}" + ) + + allocated_pages = self.page_idx_pool[ + self.mark_page_start : self.mark_page_start + total_pages_needed + ].cuda() + + def get_offsets_by_length(b_len, max_len): + # 例:b_len = [3,4,5] -> [0,1,2,0,1,2,3,0,1,2,3,4] + offsets = torch.arange(max_len, dtype=b_len.dtype, device=b_len.device) + offset_mask = offsets.unsqueeze(0) < b_len.unsqueeze(1) + return torch.masked_select(offsets, offset_mask) + + page_offsets = get_offsets_by_length(b_page_len, b_page_len.max()) + token_offsets = get_offsets_by_length(p_token_len, page_size) + + # 更新req_to_page_indexs, b_ready_cache_len必整除page_size + page_starts = b_ready_cache_len // page_size + req_id = torch.repeat_interleave( + torch.arange(len(b_req_idx), dtype=b_token_len.dtype, device=b_token_len.device), b_page_len + ) + self.req_to_page_indexs[b_req_idx[req_id], page_starts[req_id] + page_offsets] = allocated_pages + + self.mark_page_start += total_pages_needed + self.can_use_page_size -= total_pages_needed + page_bases = allocated_pages * page_size + return torch.repeat_interleave(page_bases, p_token_len) + token_offsets + else: + b_seq_len = b_seq_len.cuda() + b_req_idx = b_req_idx.cuda() + need_new_page_mask = (b_seq_len - 1) % page_size == 0 + new_pages_num = need_new_page_mask.sum() + if self.can_use_page_size < new_pages_num: + raise RuntimeError( + f"No available pages for alloc. 
remaining: {self.can_use_page_size}, needed: {new_pages_num}" + ) + + token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device) + if new_pages_num > 0: + new_pages = self.page_idx_pool[self.mark_page_start : self.mark_page_start + new_pages_num].cuda() + self.mark_page_start += new_pages_num + self.can_use_page_size -= new_pages_num + token_idxs[need_new_page_mask] = new_pages * page_size + + # 需要更新req_to_page_indexs + new_page_req_indices = b_req_idx[need_new_page_mask] + page_positions = (b_seq_len[need_new_page_mask] - 1) // page_size + self.req_to_page_indexs[new_page_req_indices, page_positions] = new_pages + + mask = ~need_new_page_mask + if mask.any(): + seq_lens = b_seq_len[mask] + token_idxs[mask] = ( + self.req_to_token_indexs[b_req_idx[mask], seq_lens - 2] // page_size * page_size + + (seq_lens - 1) % page_size + ) + return token_idxs + + def alloc(self, need_size, b_req_idx, b_seq_len, b_ready_cache_len=None, is_prefill=False) -> torch.Tensor: + page_size = get_page_size() + token_idxs = self.get_paged_token_indexs(b_req_idx, page_size, b_seq_len, b_ready_cache_len, is_prefill) + self.can_use_mem_size -= need_size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + return token_idxs + + def free(self, free_index: Union[torch.Tensor, List[int]]): + self.can_use_mem_size += len(free_index) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + page_size = get_page_size() + if isinstance(free_index, list): + free_index = torch.tensor(free_index, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True) + + if len(free_index) == 0: + return + + page_indices = free_index // page_size + unique_pages = torch.unique(page_indices) + for page_idx in sorted(unique_pages, reverse=True): # 逆序放回,保持池的相对顺序 + self.mark_page_start -= 1 + self.page_idx_pool[self.mark_page_start] = page_idx + self.can_use_page_size += 1 + + return + + def free_all(self): + super().free_all() + page_size = get_page_size() + self.mark_page_start = 0 + self.can_use_page_size = cdiv(self.size, page_size) + self.page_idx_pool = torch.arange( + 0, cdiv(self.size, page_size), dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 135cbbc22..0786bbb08 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -5,7 +5,7 @@ from typing import List, Optional from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter -from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args +from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_page_size from lightllm.utils.config_utils import get_vocab_size logger = init_logger(__name__) @@ -63,6 +63,14 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda" ) mem_manager.req_to_token_indexs = self.req_to_token_indexs + if hasattr(mem_manager, "req_to_page_indexs"): + page_size = get_page_size() + self.req_to_page_indexs = torch.zeros( + (max_request_num + 1, (max_sequence_length + page_size - 1) // page_size), + dtype=torch.int32, + device="cuda", + ) + mem_manager.req_to_page_indexs = self.req_to_page_indexs self.mem_manager = mem_manager self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num) self.max_request_num = 
max_request_num diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py index 02da92d9e..b4f7bd096 100644 --- a/lightllm/models/llama/flashattention_infer_struct.py +++ b/lightllm/models/llama/flashattention_infer_struct.py @@ -3,12 +3,16 @@ import numpy as np import torch.distributed as dist from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.utils.dist_utils import get_current_device_id from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index from lightllm.common.basemodel.batch_objs import ModelInput +def cdiv(a, b): + return (a + b - 1) // b + + class FlashAttentionStateInfo(LlamaInferStateInfo): _shared_page_table_buffer = None @@ -29,32 +33,34 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if self.is_prefill: self.cu_seqlens_q = self.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() - self.page_table = torch.empty( - (self.batch_size, self.max_seq_len), dtype=torch.int32, device=input_ids.device - ) - self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len]) + length = cdiv(self.max_seq_len, get_page_size()) + self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device) + if "page_size_variable" in model.mode: + self.page_table.copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length]) + else: + self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, :length]) else: # Meta information of flashattention for decoding self.cu_seqlens_q = self.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() max_seq_len_k = self.max_kv_seq_len if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - page_buffer = FlashAttentionStateInfo.get_page_table_buffer( - model.graph_max_batch_size, model.graph_max_len_in_batch + page_size = get_page_size() + length = cdiv(model.graph_max_len_in_batch, page_size) + page_buffer = FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length) + self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape( + self.batch_size, length ) - self.page_table = page_buffer[self.microbatch_index][ - : self.batch_size * model.graph_max_len_in_batch - ].reshape(self.batch_size, model.graph_max_len_in_batch) else: - self.page_table = torch.empty( - (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device - ) + length = cdiv(self.max_len_in_batch, get_page_size()) + self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device) - self.page_table[:, :max_seq_len_k].copy_( - model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k], - non_blocking=True, - ) - self.page_table[:, max_seq_len_k:].fill_(0) + length = cdiv(max_seq_len_k, get_page_size()) + if "page_size_variable" in model.mode: + self.page_table[:, :length].copy_(model.req_manager.req_to_page_indexs[self.b_req_idx, :length]) + else: + self.page_table[:, :length].copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, :length]) + self.page_table[:, length:].fill_(0) if "offline_calibration_fp8kv" in model.mode: if self.is_prefill: diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py 
b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 4b06a75c3..bf1d64f62 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -27,7 +27,7 @@ from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_dequantize_kv from lightllm.distributed.communication_op import all_gather_into_tensor, reduce_scatter_tensor from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops @@ -87,6 +87,14 @@ def _bind_attention(self): self._copy_kv_to_mem_cache = partial( LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self ) + elif "page_size_variable" in self.mode: + self._context_attention_kernel = partial( + LlamaTransformerLayerInfer._paged_context_attention_flashattention, self + ) + self._token_attention_kernel = partial( + LlamaTransformerLayerInfer._paged_token_decode_attention_flashattention, self + ) + self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif not self.mode: self._context_attention_kernel = partial( LlamaTransformerLayerInfer._context_attention_flashattention, self @@ -317,6 +325,39 @@ def _context_attention_kernel_ppl_int8kv( ) return o_tensor + def _paged_context_attention_flashattention( + self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None + ): + page_size = get_page_size() + cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( + -1, page_size, self.tp_k_head_num_, self.head_dim_ + ) + cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ].reshape(-1, page_size, self.tp_v_head_num_, self.head_dim_) + q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) + k_descale, v_descale = None, None # disable quantization + Lq = q.shape[-1] + sm_scale = 1.0 / (Lq ** 0.5) + o = flash_attn_with_kvcache( + q=q, + k_cache=cache_k, + v_cache=cache_v, + page_table=infer_state.page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.cu_seqlens_k, + max_seqlen_q=infer_state.q_max_seq_len, + softmax_scale=sm_scale, + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + ) + return o + def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None): cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( -1, 1, self.tp_k_head_num_, self.head_dim_ @@ -824,6 +865,39 @@ def _token_decode_attention_gqa_flashdecoding_vsm( alloc_tensor_func=self.alloc_tensor, ) + def _paged_token_decode_attention_flashattention( + self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None + ): + page_size = get_page_size() + cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( + -1, page_size, self.tp_k_head_num_, self.head_dim_ + ) + cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + 
self.tp_v_head_num_, : + ].reshape(-1, page_size, self.tp_v_head_num_, self.head_dim_) + q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) + k_descale, v_descale = None, None # disable quantization + Lq = q.shape[-1] + sm_scale = 1.0 / (Lq ** 0.5) + o = flash_attn_with_kvcache( + q=q, + k_cache=cache_k, + v_cache=cache_v, + page_table=infer_state.page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.cu_seqlens_k, + max_seqlen_q=1, + softmax_scale=sm_scale, + causal=False, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + ) + return o + def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None): cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( -1, 1, self.tp_k_head_num_, self.head_dim_ diff --git a/lightllm/server/router/dynamic_prompt/paged_radix_cache.py b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py new file mode 100644 index 000000000..33210b90d --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py @@ -0,0 +1,462 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/router/radix_cache.py +import torch +import numpy as np +from typing import Tuple, Dict, Set, List +from sortedcontainers import SortedSet +from .shared_arr import SharedArray +from lightllm.utils.envs_utils import get_page_size + + +class UniqueTimeIdGenerator: + def __init__(self): + self.counter = 0 + + def generate_time_id(self): + self.counter += 1 + return self.counter + + +time_gen = UniqueTimeIdGenerator() + + +class TreeNode: + def __init__(self): + self.children: Dict[int, TreeNode] = {} # page_hash -> TreeNode + self.parent: TreeNode = None + self.token_id_key: torch.Tensor = None + self.token_mem_index_value: torch.Tensor = None # 用于记录存储的 token_index 为每个元素在 token mem 中的index位置 + self.ref_counter = 0 + self.time_id = time_gen.generate_time_id() # 用于标识时间周期 + + self.node_value_len = 0 + self.node_prefix_total_len = 0 + self.total_children_count = 0 + self.page_size = get_page_size() + self._page_size_is_power_of_2 = (self.page_size & (self.page_size - 1)) == 0 + self._page_size_mask = self.page_size - 1 if self._page_size_is_power_of_2 else None + + def get_compare_key(self): + return (0 if self.ref_counter == 0 else 1, self.total_children_count, self.time_id) + + def _compute_key(self, tokens: torch.Tensor) -> int: + page_tokens = tokens[: self.page_size] + return page_tokens.item() if self.page_size == 1 else hash(page_tokens.cpu().numpy().tobytes()) + + def find_matched_child(self, token_id_key: torch.Tensor) -> Tuple["TreeNode", int]: + target_key = self._compute_key(token_id_key) + if target_key in self.children: + child = self.children[target_key] + prefix_len = match(token_id_key, child.token_id_key) + # 只匹配page_size的整数倍长度 + if self.page_size > 1: + if prefix_len % self.page_size != 0: + if self._page_size_is_power_of_2: + # 位运算加速 + prefix_len = prefix_len & ~self._page_size_mask + else: + prefix_len = (prefix_len // self.page_size) * self.page_size + if prefix_len == 0: + return None, 0 + return child, prefix_len + + return None, 0 + + def split_node(self, prefix_len): + split_parent_node = TreeNode() + split_parent_node.parent = self.parent + self.parent.children[self._compute_key(self.token_id_key)] = split_parent_node + + split_parent_node.token_id_key = self.token_id_key[0:prefix_len] 
+ split_parent_node.token_mem_index_value = self.token_mem_index_value[0:prefix_len] + split_parent_node.children = {} + + remaining_tokens = self.token_id_key[prefix_len:] + split_parent_node.children[self._compute_key(remaining_tokens)] = self + split_parent_node.ref_counter = self.ref_counter + split_parent_node.total_children_count = 1 + + new_len = len(split_parent_node.token_mem_index_value) + split_parent_node.node_value_len = new_len + split_parent_node.node_prefix_total_len = split_parent_node.parent.node_prefix_total_len + new_len + + self.token_id_key = remaining_tokens + self.token_mem_index_value = self.token_mem_index_value[prefix_len:] + self.parent = split_parent_node + new_len = len(self.token_mem_index_value) + self.node_value_len = new_len + self.node_prefix_total_len = self.parent.node_prefix_total_len + new_len + return split_parent_node + + def add_and_return_new_child(self, token_id_key, token_mem_index_value): + child = TreeNode() + child.token_id_key = token_id_key + child.token_mem_index_value = token_mem_index_value + + self.children[self._compute_key(token_id_key)] = child + child.parent = self + self.total_children_count += 1 + + new_len = len(child.token_mem_index_value) + child.node_value_len = new_len + child.node_prefix_total_len = child.parent.node_prefix_total_len + new_len + return child + + def remove_child(self, child_node: "TreeNode"): + del self.children[self._compute_key(child_node.token_id_key)] + child_node.parent = None + self.total_children_count -= 1 + return + + def update_time(self): + self.time_id = time_gen.generate_time_id() + + def is_leaf(self): + return self.total_children_count == 0 + + +def match(t1: torch.Tensor, t2: torch.Tensor) -> int: + # Ensure same shape for comparison: flatten and get min length + t1_flat = t1.flatten() + t2_flat = t2.flatten() + min_len = min(t1_flat.size(0), t2_flat.size(0)) + + # Compare elements and find first mismatch + diff = t1_flat[:min_len] != t2_flat[:min_len] + mismatch_indices = torch.nonzero(diff) + + if mismatch_indices.numel() == 0: + return min_len # All matched up to min_len + else: + return mismatch_indices[0].item() + + +class PagedRadixCache: + """ + unique_name 主要用于解决单机,多实列部署时的shm冲突 + """ + + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): + self.mem_manager = mem_manager + self._key_dtype = torch.int64 + self._value_dtype = torch.int64 + # 预计算page_size相关的常量 + self.page_size = get_page_size() + self._page_size_is_power_of_2 = (self.page_size & (self.page_size - 1)) == 0 + self._page_size_mask = self.page_size - 1 if self._page_size_is_power_of_2 else None + + self.root_node = TreeNode() + self.root_node.token_id_key = torch.zeros((0,), device="cpu", dtype=self._key_dtype) + self.root_node.token_mem_index_value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + self.root_node.ref_counter = 1 # 初始化为 1 保证永远不会被 evict 掉 + + self.evict_tree_set: Set[TreeNode] = SortedSet(key=lambda x: x.get_compare_key()) # 自定义比较器 + self.evict_tree_set.add(self.root_node) + + self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64) + self.refed_tokens_num.arr[0] = 0 + self.tree_total_tokens_num = SharedArray( + f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 + ) + self.tree_total_tokens_num.arr[0] = 0 + + def _get_page_aligned_key(self, key, value=None): + aligned_len = len(key) + if aligned_len == 0: + return None, None + # page_size > 1时, 需要确保输入的key长度是page_size的整数倍 + if self.page_size > 1: 
+ if aligned_len % self.page_size != 0: + if self._page_size_is_power_of_2: + # 位运算加速 + aligned_len = aligned_len & ~self._page_size_mask + else: + aligned_len = (aligned_len // self.page_size) * self.page_size + return ( + key[:aligned_len] if aligned_len > 0 else None, + value[:aligned_len] if value is not None and aligned_len > 0 else None, + ) + return key, value + + def insert(self, key, value=None): + if value is None: + value = key + + assert len(key) == len(value) # and len(key) >= 1 + key, value = self._get_page_aligned_key(key, value) + if key is None: + return 0 + return self._insert_helper(self.root_node, key, value) + + def _insert_helper(self, node: TreeNode, key, value): + if node.is_leaf(): + self.evict_tree_set.discard(node) + + try: + child, prefix_len = node.find_matched_child(key) + if child is not None: + if prefix_len == len(key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + child.update_time() + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len + elif prefix_len < len(key) and prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + remaining_key = key[prefix_len:] + remaining_value = value[prefix_len:] + split_parent_node = child.split_node(prefix_len) + new_node = split_parent_node.add_and_return_new_child(remaining_key, remaining_value) + # update total token num + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len + elif prefix_len < len(key) and prefix_len == len(child.token_id_key): + return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:]) + else: + assert False, "can not run to here" + + new_node = node.add_and_return_new_child(key, value) + # update total token num + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + return 0 + finally: + node.update_time() + if node.is_leaf(): + self.evict_tree_set.add(node) + + def match_prefix(self, key, update_refs=False): + assert len(key) != 0 + key, _ = self._get_page_aligned_key(key) + if key is None: + return None, 0, None + + ans_value_list = [] + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) + if tree_node != self.root_node: + if len(ans_value_list) != 0: + value = torch.concat(ans_value_list) + else: + value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + return tree_node, len(value), value + else: + self.dec_node_ref_counter(self.root_node) + return None, 0, None + + def _match_prefix_helper(self, node: TreeNode, key, ans_value_list: list, update_refs=False) -> TreeNode: + if node.is_leaf(): + self.evict_tree_set.discard(node) + + if update_refs: + node.ref_counter += 1 + # from 0 to 1 need update refs token num + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) + + try: + if len(key) == 0: + return node + + child, prefix_len = node.find_matched_child(key) + if child is not None: + if prefix_len == len(child.token_id_key): + ans_value_list.append(child.token_mem_index_value) + return self._match_prefix_helper(child, key[prefix_len:], ans_value_list, update_refs=update_refs) + elif prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) 
+ + split_parent_node = child.split_node(prefix_len) + ans_value_list.append(split_parent_node.token_mem_index_value) + + if update_refs: + split_parent_node.ref_counter += 1 + # from 0 to 1 need update refs token num + if split_parent_node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(split_parent_node.token_mem_index_value) + + if child.is_leaf(): + self.evict_tree_set.add(child) + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + + return split_parent_node + else: + assert False, "error state" + + return node + finally: + node.update_time() + if node.is_leaf(): + self.evict_tree_set.add(node) + + def evict(self, need_remove_tokens, evict_callback): + if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: + assert False, f"""can not free tree tokens {need_remove_tokens}, + tree_total_tokens_num {self.tree_total_tokens_num.arr[0]}, + refed_tokens_num {self.refed_tokens_num.arr[0]}""" + num_evicted = 0 + while num_evicted < need_remove_tokens: + node: TreeNode = self.evict_tree_set.pop(0) + assert node.ref_counter == 0 and node.is_leaf() and node != self.root_node, "error evict tree node state" + num_evicted += len(node.token_mem_index_value) + evict_callback(node.token_mem_index_value) + # update total token num + self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + + return + + def assert_leafs_is_right(self): + for node in self.evict_tree_set: + if node.is_leaf() and node.ref_counter == 0: + a = node.token_mem_index_value.cuda() + assert (self.mem_manager.mem_state[a] == 1).sum().item() == len(a) + + def clear_tree_nodes(self): + """ + 该函数只在测试时调用 + """ + while True: + node: TreeNode = self.evict_tree_set.pop(0) + if node != self.root_node: + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + else: + break + + self.tree_total_tokens_num.arr[0] = 0 + self.refed_tokens_num.arr[0] = 0 + return + + def dec_node_ref_counter(self, node: TreeNode): + if node is None: + return + # 如果减引用的是叶节点,需要先从 evict_tree_set 中移除 + old_node = node + if old_node.is_leaf(): + self.evict_tree_set.discard(old_node) + + while node is not None: + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] -= len(node.token_mem_index_value) + node.ref_counter -= 1 + node = node.parent + + # 加回。 + if old_node.is_leaf(): + self.evict_tree_set.add(old_node) + return + + def get_refed_tokens_num(self): + return self.refed_tokens_num.arr[0] + + def get_tree_total_tokens_num(self): + return self.tree_total_tokens_num.arr[0] + + def print_self(self, indent=0): + self._print_helper(self.root_node, indent) + + def _print_helper(self, node: TreeNode, indent): + print( + " " * indent, + f"k: {node.token_id_key[0:10]} v: {node.token_mem_index_value[0:10]} refs: {node.ref_counter} \ + time_id: {node.time_id} prefix_total_len: {node.node_prefix_total_len} \ + node_value_len: {node.node_value_len}", + ) + for _, child in node.children.items(): + self._print_helper(child, indent=indent + 2) + return + + def free_radix_cache_to_get_enough_token( + self, need_token_num=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False + ): + assert self.mem_manager is not None + need_pages = 0 + can_use_pages = 0 + if hasattr(self.mem_manager, "can_use_page_size") and self.page_size > 1 and b_seq_len is not None: + + def 
get_need_page_size(page_size, b_seq_len, b_ready_cache_len=None, is_prefill=False): + need_new_pages = 0 + if is_prefill: + need_tokens_array = b_seq_len - b_ready_cache_len + need_pages_array = (need_tokens_array + page_size - 1) // page_size + need_new_pages = need_pages_array.sum() + else: + mask = (b_seq_len - 1) % page_size == 0 + need_new_pages = mask.sum() + return need_new_pages + + need_pages = get_need_page_size(self.page_size, b_seq_len, b_ready_cache_len, is_prefill) + can_use_pages = self.mem_manager.can_use_page_size + if need_token_num > self.mem_manager.can_use_mem_size or need_pages > can_use_pages: + need_evict_single_token_num = need_token_num - self.mem_manager.can_use_mem_size + need_evict_page_token_num = (need_pages - can_use_pages) * self.page_size + need_evict_token_num = max(need_evict_single_token_num, need_evict_page_token_num) + remaining_tokens = self.get_tree_total_tokens_num() - self.get_refed_tokens_num() + need_evict_token_num = min(need_evict_token_num, remaining_tokens) + release_mems = [] + + def release_mem(mem_index): + release_mems.append(mem_index) + return + + self.evict(need_evict_token_num, release_mem) + if release_mems: + mem_index = torch.concat(release_mems) + self.mem_manager.free(mem_index) + return + + +class _RadixCacheReadOnlyClient: + """ + router 端只读用的客户端,用于从共享内存中读取树结构中的信息,用于进行prompt cache 的调度估计。 + """ + + def __init__(self, unique_name, total_token_num, rank_in_node): + self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64) + self.tree_total_tokens_num = SharedArray( + f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 + ) + + def get_refed_tokens_num(self): + return self.refed_tokens_num.arr[0] + + def get_tree_total_tokens_num(self): + return self.tree_total_tokens_num.arr[0] + + def get_unrefed_tokens_num(self): + return self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] + + +class RadixCacheReadOnlyClient: + def __init__(self, unique_name, total_token_num, node_world_size, dp_world_size): + self.dp_rank_clients: List[_RadixCacheReadOnlyClient] = [ + _RadixCacheReadOnlyClient(unique_name, total_token_num, rank_in_node) + for rank_in_node in range(0, node_world_size, dp_world_size) + ] + + def get_refed_tokens_num(self, dp_rank_in_node): + return self.dp_rank_clients[dp_rank_in_node].get_refed_tokens_num() + + def get_tree_total_tokens_num(self, dp_rank_in_node): + return self.dp_rank_clients[dp_rank_in_node].get_tree_total_tokens_num() + + def get_unrefed_tokens_num(self, dp_rank_in_node): + return self.dp_rank_clients[dp_rank_in_node].get_unrefed_tokens_num() diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 65ec4354b..a60d0a942 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -333,7 +333,9 @@ def _print_helper(self, node: TreeNode, indent): self._print_helper(child, indent=indent + 2) return - def free_radix_cache_to_get_enough_token(self, need_token_num): + def free_radix_cache_to_get_enough_token( + self, need_token_num=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=False + ): assert self.mem_manager is not None if need_token_num > self.mem_manager.can_use_mem_size: need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py 
b/lightllm/server/router/model_infer/mode_backend/base_backend.py index fd75afdbf..5284d7fa9 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -10,6 +10,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.models import get_model from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache +from lightllm.server.router.dynamic_prompt.paged_radix_cache import PagedRadixCache from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock @@ -139,8 +140,9 @@ def init_model(self, kvargs): self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) + radix_cache_class = PagedRadixCache if "page_size_variable" in self.mode else RadixCache self.radix_cache = ( - RadixCache( + radix_cache_class( get_unique_server_name(), self.model.mem_manager.size, self.rank_in_node, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 5059ad27f..448c0d987 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -77,7 +77,9 @@ def padded_prepare_prefill_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token( + input_ids.shape[0] - padded_req_num, b_seq_len, b_ready_cache_len, True + ) mem_indexes = g_infer_context.req_manager.mem_manager.alloc( input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len, True ) @@ -164,7 +166,7 @@ def padded_prepare_decode_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_req_num) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_req_num, b_seq_len) mem_indexes = g_infer_context.req_manager.mem_manager.alloc( b_seq_len.shape[0] - padded_req_num, b_req_idx, b_seq_len ) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 1c81ebfcd..e5e871d83 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -55,7 +55,9 @@ def prepare_prefill_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token( + input_ids.shape[0], b_seq_len, b_ready_cache_len, True + ) mem_indexes = g_infer_context.req_manager.mem_manager.alloc( input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True ) @@ -113,7 +115,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In # dynamic prompt cache 准备 
token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0]) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0], b_seq_len) mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0], b_req_idx, b_seq_len) g_infer_state_lock.release() diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index b78784d82..9d73b2a1d 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -149,6 +149,15 @@ def get_kv_quant_calibration_inference_count(): return int(os.getenv("LIGHTLLM_KV_QUANT_CALIBRARTION_INFERENCE_COUNT", 4000)) +@lru_cache(maxsize=None) +def get_page_size(): + try: + args = get_env_start_args() + return int(os.getenv("PAGE_SIZE", 4)) if "page_size_variable" in args.mode else 1 + except: + return 1 + + g_model_init_done = False From 1a6af0a4aa5c9f46e211f56bdc65e7848337cd8e Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Fri, 1 Aug 2025 18:21:07 +0800 Subject: [PATCH 3/3] feat: support page size variable for flashinfer --- lightllm/models/llama/flashinfer_struct.py | 89 +++++--- .../layer_infer/transformer_layer_infer.py | 52 ++++- lightllm/utils/envs_utils.py | 2 +- .../test_context_flashattention_nopad.py | 5 +- ..._context_flashattention_nopad_fa3_paged.py | 163 +++++++++++++ ...t_flashattention_nopad_flashinfer_paged.py | 214 ++++++++++++++++++ .../test_token_attention_nopad_fa3_paged.py | 186 +++++++++++++++ ..._token_attention_nopad_flashinfer_paged.py | 169 ++++++++++++++ 8 files changed, 843 insertions(+), 37 deletions(-) create mode 100644 unit_tests/models/llama/test_context_flashattention_nopad_fa3_paged.py create mode 100644 unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_paged.py create mode 100644 unit_tests/models/llama/test_token_attention_nopad_fa3_paged.py create mode 100644 unit_tests/models/llama/test_token_attention_nopad_flashinfer_paged.py diff --git a/lightllm/models/llama/flashinfer_struct.py b/lightllm/models/llama/flashinfer_struct.py index cea95a203..3b9a378c4 100644 --- a/lightllm/models/llama/flashinfer_struct.py +++ b/lightllm/models/llama/flashinfer_struct.py @@ -3,16 +3,21 @@ import numpy as np import torch.distributed as dist from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index +def cdiv(a, b): + return (a + b - 1) // b + + class LlamaFlashInferStateInfo(LlamaInferStateInfo): def __init__(self): super().__init__() self.prefill_wrapper = None self.decode_wrapper = None self.flashinfer_extra_state = None + self.page_size = get_page_size() def init_some_extra_state(self, model, input_ids: torch.Tensor): super().init_some_extra_state(model, input_ids) @@ -22,29 +27,41 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if not self.is_prefill: if get_env_start_args().enable_flashinfer_decode: - self.kv_last_page_len_buffer = torch.full( - (self.batch_size,), 1, dtype=torch.int32, device=input_ids.device - ) + self.kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=input_ids.device) + length = cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size) if self.batch_size <= model.graph_max_batch_size: self.kv_indices = 
self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ - : self.batch_size * self.flashinfer_extra_state.max_seq_length + : self.batch_size * length ] else: self.kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, + self.batch_size * length, dtype=torch.int32, device=input_ids.device, ) - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - self.b_start_loc, - self.max_len_in_batch, - self.kv_indices, - ) self.kv_starts = self.b1_cu_kv_seq_len.int() + if "page_size_variable" in model.mode: + b_page_len = cdiv(self.b_seq_len, self.page_size) + self.kv_starts[1:] = b_page_len.cumsum(0) + self.kv_last_page_len = self.b_seq_len - (b_page_len - 1) * self.page_size + repack_kv_index( + self.req_manager.req_to_page_indexs, + self.b_req_idx, + b_page_len, + self.kv_starts[:-1], + cdiv(self.max_kv_seq_len, self.page_size), + self.kv_indices, + ) + else: + repack_kv_index( + self.req_manager.req_to_token_indexs, + self.b_req_idx, + self.b_seq_len, + self.b_start_loc, + self.max_kv_seq_len, + self.kv_indices, + ) if self.decode_wrapper is None: self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( self.flashinfer_extra_state.workspace_buffer, @@ -53,16 +70,16 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): use_tensor_cores=True, paged_kv_indptr_buffer=self.kv_starts, paged_kv_indices_buffer=self.kv_indices, - paged_kv_last_page_len_buffer=self.kv_last_page_len_buffer, + paged_kv_last_page_len_buffer=self.kv_last_page_len, ) self.decode_wrapper.plan( self.kv_starts, self.kv_indices, - self.kv_last_page_len_buffer, + self.kv_last_page_len, self.flashinfer_extra_state.tp_q_head_num, self.flashinfer_extra_state.tp_kv_head_num, self.flashinfer_extra_state.head_dim, - 1, + self.page_size, q_data_type=self.flashinfer_extra_state.q_data_type, kv_data_type=self.flashinfer_extra_state.kv_data_type, non_blocking=True, @@ -72,19 +89,33 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): q_starts = self.b1_cu_q_seq_len.int() kv_starts = self.b1_cu_kv_seq_len.int() kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=input_ids.device) + length = cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size) kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, + self.batch_size * length, dtype=torch.int32, device=input_ids.device, ) - repack_kv_index( - self.req_manager.req_to_token_indexs, - self.b_req_idx, - self.b_seq_len, - kv_starts, - self.max_len_in_batch, - kv_indices, - ) + if "page_size_variable" in model.mode: + b_page_len = cdiv(self.b_seq_len, self.page_size) + kv_starts[1:] = b_page_len.cumsum(0) + kv_last_page_len = self.b_seq_len - (b_page_len - 1) * self.page_size + repack_kv_index( + self.req_manager.req_to_page_indexs, + self.b_req_idx, + b_page_len, + kv_starts[:-1], + cdiv(self.max_kv_seq_len, self.page_size), + kv_indices, + ) + else: + repack_kv_index( + self.req_manager.req_to_token_indexs, + self.b_req_idx, + self.b_seq_len, + kv_starts[:-1], + self.max_kv_seq_len, + kv_indices, + ) self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( self.flashinfer_extra_state.workspace_buffer, qo_indptr_buf=q_starts, @@ -100,7 +131,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.flashinfer_extra_state.tp_q_head_num, self.flashinfer_extra_state.tp_kv_head_num, self.flashinfer_extra_state.head_dim, - 1, + self.page_size, causal=True, 
pos_encoding_mode="NONE", logits_soft_cap=0.0, @@ -115,11 +146,11 @@ def copy_for_cuda_graph(self, new_infer_state): self.decode_wrapper.plan( new_infer_state.kv_starts, new_infer_state.kv_indices, - new_infer_state.kv_last_page_len_buffer, + new_infer_state.kv_last_page_len, new_infer_state.flashinfer_extra_state.tp_q_head_num, new_infer_state.flashinfer_extra_state.tp_kv_head_num, new_infer_state.flashinfer_extra_state.head_dim, - 1, + self.page_size, q_data_type=new_infer_state.flashinfer_extra_state.q_data_type, kv_data_type=new_infer_state.flashinfer_extra_state.kv_data_type, non_blocking=True, diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index bf1d64f62..850a8c3f9 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -107,9 +107,16 @@ def _bind_attention(self): raise Exception(f"Unsupported mode for fa3 backend: {self.mode}") return elif get_env_start_args().enable_flashinfer_prefill: - self._context_attention_kernel = partial( - LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self - ) + if "page_size_variable" in self.mode: + self._context_attention_kernel = partial( + LlamaTransformerLayerInfer._paged_context_attention_flashinfer_kernel, self + ) + elif not self.mode: + self._context_attention_kernel = partial( + LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self + ) + else: + raise Exception(f"Unsupported mode for flashinfer backend: {self.mode}") else: self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) if "ppl_int8kv" in self.mode: @@ -174,6 +181,12 @@ def _bind_attention(self): self._copy_kv_to_mem_cache = partial( LlamaTransformerLayerInfer._copy_kv_to_mem_cache_with_calibration, self ) + elif "page_size_variable" in self.mode: + assert get_env_start_args().enable_flashinfer_prefill and get_env_start_args().enable_flashinfer_decode + self._token_attention_kernel = partial( + LlamaTransformerLayerInfer._paged_token_decode_attention_flashinfer, self + ) + self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) elif not self.mode: if get_env_start_args().enable_flashinfer_decode: self._token_attention_kernel = partial( @@ -274,6 +287,21 @@ def _context_attention_flashinfer_kernel( ) return o_tensor + def _paged_context_attention_flashinfer_kernel( + self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + page_size = get_page_size() + kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view( + -1, page_size, 2 * self.tp_k_head_num_, self.head_dim_ + ) + infer_state.prefill_wrapper.run( + q.view(q.shape[0], -1, self.head_dim_), + (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), + out=o_tensor.view(q.shape[0], -1, self.head_dim_), + ) + return o_tensor + def _context_attention_kernel( self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None ) -> torch.Tensor: @@ -587,6 +615,24 @@ def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStat ) return o_tensor + def _paged_token_decode_attention_flashinfer( + self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None + ): + batch_size = infer_state.batch_size + calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) + + o_tensor = 
self.alloc_tensor(q.shape, q.dtype) if out is None else out + page_size = get_page_size() + kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view( + -1, page_size, 2 * self.tp_k_head_num_, self.head_dim_ + ) + infer_state.decode_wrapper.run( + q.view(calcu_shape1), + (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), + out=o_tensor.view(calcu_shape1), + ) + return o_tensor + def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): total_token_num = infer_state.total_token_num batch_size = infer_state.batch_size diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 9d73b2a1d..1cb030cb1 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -153,7 +153,7 @@ def get_kv_quant_calibration_inference_count(): def get_page_size(): try: args = get_env_start_args() - return int(os.getenv("PAGE_SIZE", 4)) if "page_size_variable" in args.mode else 1 + return int(os.getenv("PAGE_SIZE", 64)) if "page_size_variable" in args.mode else 1 except: return 1 diff --git a/unit_tests/models/llama/test_context_flashattention_nopad.py b/unit_tests/models/llama/test_context_flashattention_nopad.py index f24ab619b..94e61cfda 100644 --- a/unit_tests/models/llama/test_context_flashattention_nopad.py +++ b/unit_tests/models/llama/test_context_flashattention_nopad.py @@ -10,7 +10,6 @@ context_attention_fwd_no_prompt_cache, ) from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.common.req_manager import ReqManager logger = init_logger(__name__) @@ -56,8 +55,6 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): infer_state.batch_size = Z infer_state.max_len_in_batch = N_CTX infer_state.total_token_num = Z * N_CTX - infer_state.req_manager = ReqManager(Z, N_CTX, None) - infer_state.req_manager.req_to_token_indexs = req_to_token_indexs infer_state.b_req_idx = b_req_idx infer_state.b_seq_len = b_seq_len infer_state.b_ready_cache_len = b_ready_cache_len @@ -73,7 +70,7 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): infer_state.b_seq_len, infer_state.b_ready_cache_len, infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, + req_to_token_indexs, ) batch_size = Z diff --git a/unit_tests/models/llama/test_context_flashattention_nopad_fa3_paged.py b/unit_tests/models/llama/test_context_flashattention_nopad_fa3_paged.py new file mode 100644 index 000000000..e7702f084 --- /dev/null +++ b/unit_tests/models/llama/test_context_flashattention_nopad_fa3_paged.py @@ -0,0 +1,163 @@ +import torch +import time +import pytest +import triton as tl +import numpy as np +import torch.nn.functional as F +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd, +) +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def kv_quantize_per_head_fp8(kv_buffer: torch.Tensor, seq_lens): + device = kv_buffer.device + B = seq_lens.size(0) + min_fp8 = torch.finfo(torch.float8_e4m3fn).min + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + _, S_max, H, D = kv_buffer.shape + seq_range = 
torch.arange(S_max, device=device)[None, :] + valid_mask = (seq_range < seq_lens[:, None]).view(B, S_max, 1, 1) + masked = kv_buffer * valid_mask + max_per_bh = masked.abs().amax(dim=(1, 3)) # [B, H] + scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)).to(torch.float32) + scales_exp = scales.view(B, 1, H, 1) + q = (kv_buffer / scales_exp).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn) + return q, scales + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_context_attention_fwd_fa3_fp8(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + page_size = 4 + kv = torch.randn((Z * N_CTX // page_size, page_size, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_page_indexs = ( + torch.randperm(max_input_len // page_size, dtype=torch.int32).cuda().view(Z, N_CTX // page_size) + ) + req_to_token_indexs = ( + req_to_page_indexs.unsqueeze(-1) * page_size + torch.arange(page_size, dtype=torch.int32, device="cuda") + ).reshape(Z, N_CTX) + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) + rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") + b_seq_len += rand_num + b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") + if N_CTX > 1: + b_ready_cache_len = torch.randint_like(b_seq_len, high=(N_CTX - 1) // 2, dtype=torch.int32, device="cuda") + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + q_lens = b_seq_len - b_ready_cache_len + q_start_loc = q_lens.cumsum(0) - q_lens + + q = torch.randn((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_ready_cache_len = b_ready_cache_len + infer_state.b_start_loc = q_start_loc + + context_attention_fwd( + q, + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :], + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :], + o, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + req_to_token_indexs, + ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + page_table = torch.empty((batch_size, N_CTX), dtype=torch.int32, device="cuda") + page_table.copy_(req_to_token_indexs[b_req_idx, :N_CTX]) + + q_starts = torch.zeros((Z + 1,)).int().cuda() + q_starts[1:] = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0) + kv_starts = torch.zeros_like(q_starts) + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + + k_cache = kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :] + v_cache = kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :] + o1 = flash_attn_with_kvcache( + q=q, + k_cache=k_cache.reshape(-1, 1, kv_heads, head_dim), + v_cache=v_cache.reshape(-1, 1, kv_heads, head_dim), + page_table=page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + 
cu_seqlens_k_new=kv_starts, + max_seqlen_q=N_CTX, + causal=True, + window_size=(-1, -1), + softcap=0.0, + return_softmax_lse=False, + ) + + assert torch.allclose(o, o1, atol=1e-2, rtol=1e-2) + cos_sim1 = F.cosine_similarity(o, o1).mean() + print(f"cos_sim1: {cos_sim1}") + assert cos_sim1.item() == 1 + + k_cache_paged = k_cache.reshape(-1, page_size, kv_heads, head_dim) + v_cache_paged = v_cache.reshape(-1, page_size, kv_heads, head_dim) + + page_table_paged = torch.empty((batch_size, N_CTX // page_size), dtype=torch.int32, device="cuda") + page_table_paged.copy_(req_to_page_indexs[b_req_idx, : N_CTX // page_size]) + + o2 = flash_attn_with_kvcache( + q=q, + k_cache=k_cache_paged, + v_cache=v_cache_paged, + page_table=page_table_paged, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=N_CTX, + causal=True, + window_size=(-1, -1), + softcap=0.0, + return_softmax_lse=False, + ) + + assert torch.allclose(o1, o2, atol=1e-2, rtol=1e-2) + cos_sim2 = F.cosine_similarity(o1, o2).mean() + print(f"cos_sim2: {cos_sim2}") + assert cos_sim2.item() == 1 + + +if __name__ == "__main__": + test_context_attention_fwd_fa3_fp8(32, 16384, 32, 4, 128) diff --git a/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_paged.py b/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_paged.py new file mode 100644 index 000000000..763a80015 --- /dev/null +++ b/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_paged.py @@ -0,0 +1,214 @@ +import torch +import time +import pytest +import numpy as np +import torch.nn.functional as F +import flashinfer +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd, + context_attention_fwd_no_prompt_cache, +) +from lightllm.models.llama.infer_struct import LlamaInferStateInfo + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + page_size = 4 + kv = torch.randn((Z * N_CTX // page_size, page_size, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_page_indexs = ( + torch.randperm(max_input_len // page_size, dtype=torch.int32).cuda().view(Z, N_CTX // page_size) + ) + req_to_token_indexs = ( + req_to_page_indexs.unsqueeze(-1) * page_size + torch.arange(page_size, dtype=torch.int32, device="cuda") + ).reshape(Z, N_CTX) + + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX + b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") + b_ready_cache_len = torch.randint_like(b_seq_len, high=N_CTX - 1, dtype=torch.int32, device="cuda") + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + q_lens = b_seq_len - b_ready_cache_len + q_start_loc = q_lens.cumsum(0) - q_lens + + q = torch.randn((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((q_lens.sum(), Q_HEADS, 
HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_ready_cache_len = b_ready_cache_len + infer_state.b_start_loc = q_start_loc + + context_attention_fwd( + q, + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :], + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :], + o, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + req_to_token_indexs, + ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) + q_starts = torch.zeros((Z + 1,)).int().cuda() + q_starts[1:] = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0) + + num_pages_per_seq = torch.ceil(b_seq_len.float() / page_size).int() + kv_starts = torch.zeros((Z + 1,)).int().cuda() + kv_starts[1:] = torch.cumsum(num_pages_per_seq, dim=0) + + q_indptr = q_starts.int() + kv_indptr = kv_starts.int() + + total_pages = num_pages_per_seq.sum().item() + kv_indices = torch.zeros(total_pages, dtype=torch.int32, device="cuda") + + # 设置kv_indices + b_start_loc = num_pages_per_seq.cumsum(0) - num_pages_per_seq + for req, sl, start in zip(b_req_idx, num_pages_per_seq, b_start_loc): + kv_indices[start : start + sl] = req_to_page_indexs[req][:sl] + + kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) + wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, + qo_indptr_buf=q_indptr, + paged_kv_indptr_buf=kv_indptr, + paged_kv_indices_buf=kv_indices, + paged_kv_last_page_len_buf=kv_last_page_len_buffer, + ) + + # 设置kv_last_page_len + kv_last_page_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + for i in range(Z): + seq_len = b_seq_len[i].item() + remainder = seq_len % page_size + kv_last_page_len[i] = remainder if remainder > 0 else page_size + + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + q_heads, + kv_heads, + head_dim, + page_size, + causal=True, + pos_encoding_mode="NONE", + logits_soft_cap=0.0, + q_data_type=q.dtype, + kv_data_type=kv.dtype, + ) + k_cache = kv[:, :, :KV_HEADS, :] + v_cache = kv[:, :, KV_HEADS:, :] + wrapper.run(q, (k_cache, v_cache), out=o1, return_lse=False) + cos_sim1 = F.cosine_similarity(o, o1).mean() + assert cos_sim1 == 1.0 + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_context_attention_fwd_no_prompt_cache(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + q = torch.randn((Z * N_CTX, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + k = torch.randn((Z * N_CTX, KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + v = torch.randn((Z * N_CTX, KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX + b_start_loc = b_seq_len.cumsum(0) - b_seq_len + + o = torch.zeros((Z * N_CTX, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((Z * N_CTX, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + 
infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.b_seq_len = b_seq_len + infer_state.b_start_loc = b_start_loc + + context_attention_fwd_no_prompt_cache( + q, + k, + v, + o, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) + q_starts = torch.zeros((Z + 1,)).int().cuda() + q_starts[1:] = torch.cumsum(b_seq_len, dim=0) + kv_starts = torch.zeros_like(q_starts) + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + q_indptr = q_starts.int() + kv_indptr = kv_starts.int() + wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, + ) + wrapper.plan( + qo_indptr=q_indptr, + kv_indptr=kv_indptr, + num_qo_heads=q_heads, + num_kv_heads=kv_heads, + head_dim_qk=head_dim, + head_dim_vo=head_dim, + q_data_type=dtype, + causal=True, + ) + wrapper.run(q, k, v, out=o1, return_lse=False) + + # assert torch.allclose(o, o1, atol=1e-2, rtol=0) + cos_sim1 = F.cosine_similarity(o, o1).mean() + assert cos_sim1 == 1.0 + + +if __name__ == "__main__": + test_context_attention_fwd(32, 16384, 32, 4, 128) # 16384 is divisible by 4 diff --git a/unit_tests/models/llama/test_token_attention_nopad_fa3_paged.py b/unit_tests/models/llama/test_token_attention_nopad_fa3_paged.py new file mode 100644 index 000000000..1de2fbc34 --- /dev/null +++ b/unit_tests/models/llama/test_token_attention_nopad_fa3_paged.py @@ -0,0 +1,186 @@ +import torch +import time +import pytest +import numpy as np +import torch.nn.functional as F +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd +from lightllm.utils.sgl_utils import flash_attn_with_kvcache + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def kv_quantize_per_head_fp8(kv_buffer: torch.Tensor, seq_lens): + device = kv_buffer.device + B = seq_lens.size(0) + min_fp8 = torch.finfo(torch.float8_e4m3fn).min + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + _, S_max, H, D = kv_buffer.shape + seq_range = torch.arange(S_max, device=device)[None, :] + valid_mask = (seq_range < seq_lens[:, None]).view(B, S_max, 1, 1) + masked = kv_buffer * valid_mask + max_per_bh = masked.float().abs().amax(dim=(1, 3)) # [B, H] + scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)) + scales_exp = scales.view(B, 1, H, 1) + q = (kv_buffer / scales_exp).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn) + return q, scales + + +def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state, req_to_token_indexs): + from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd + + total_token_num = infer_state.total_token_num + batch_size = infer_state.batch_size + calcu_shape1 = (batch_size, q_h, h_dim) + + att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() + + token_att_fwd( + q.view(calcu_shape1), + k, + att_m_tensor, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( + token_softmax_reducev_fwd, + ) + + token_softmax_reducev_fwd( 
+ att_m_tensor, + v, + o, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + ) + return o + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_token_attention_nopad_fa3_fp8(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + page_size = 4 + q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_page_indexs = ( + torch.randperm(max_input_len // page_size, dtype=torch.int32).cuda().view(Z, N_CTX // page_size) + ) + req_to_token_indexs = ( + req_to_page_indexs.unsqueeze(-1) * page_size + torch.arange(page_size, dtype=torch.int32, device="cuda") + ).reshape(Z, N_CTX) + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) + rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") + b_seq_len += rand_num + b_start_loc = b_seq_len.cumsum(0) - b_seq_len + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + + o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_start_loc = b_start_loc + + ref_token_attention_nopad( + q, + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :], + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :], + o, + Q_HEADS, + HEAD_DIM, + infer_state, + req_to_token_indexs, + ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + kv_starts = torch.zeros((Z + 1,)).int().cuda() + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + q_starts = torch.arange(0, Z + 1).int().cuda() + page_table = torch.empty((batch_size, N_CTX), dtype=torch.int32).to(0) + page_table.copy_(req_to_token_indexs[b_req_idx, :N_CTX]) + + k_cache = kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :].contiguous() + v_cache = kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :].contiguous() + o1 = flash_attn_with_kvcache( + q=q, + k_cache=k_cache.view(-1, 1, kv_heads, head_dim), + v_cache=v_cache.view(-1, 1, kv_heads, head_dim), + page_table=page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=1, + causal=False, + window_size=(-1, -1), + softcap=0.0, + return_softmax_lse=False, + ) + + # assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1) + cos_sim1 = F.cosine_similarity(o, o1).mean() + print(cos_sim1) + assert cos_sim1 == 1 + + k_cache_paged = k_cache.reshape(-1, page_size, kv_heads, head_dim) + v_cache_paged = v_cache.reshape(-1, page_size, kv_heads, head_dim) + + page_table_paged = torch.empty((batch_size, N_CTX // page_size), dtype=torch.int32, device="cuda") + page_table_paged.copy_(req_to_page_indexs[b_req_idx, : N_CTX // page_size]) + + o2 = flash_attn_with_kvcache( + q=q, + k_cache=k_cache_paged, + v_cache=v_cache_paged, + page_table=page_table_paged, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=1, 
+ causal=False, + window_size=(-1, -1), + softcap=0.0, + return_softmax_lse=False, + ) + + assert torch.allclose(o1, o2, atol=1e-2, rtol=1e-2) + cos_sim2 = F.cosine_similarity(o1, o2).mean() + print(cos_sim2) + assert cos_sim2.item() == 1 + + +if __name__ == "__main__": + test_token_attention_nopad_fa3_fp8(16, 16384, 28, 4, 128) diff --git a/unit_tests/models/llama/test_token_attention_nopad_flashinfer_paged.py b/unit_tests/models/llama/test_token_attention_nopad_flashinfer_paged.py new file mode 100644 index 000000000..9bb97be99 --- /dev/null +++ b/unit_tests/models/llama/test_token_attention_nopad_flashinfer_paged.py @@ -0,0 +1,169 @@ +import torch +import time +import pytest +import numpy as np +import torch.nn.functional as F +import flashinfer +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state, req_to_token_indexs): + from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd + + total_token_num = infer_state.total_token_num + batch_size = infer_state.batch_size + calcu_shape1 = (batch_size, q_h, h_dim) + + att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() + + token_att_fwd( + q.view(calcu_shape1), + k, + att_m_tensor, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( + token_softmax_reducev_fwd, + ) + + token_softmax_reducev_fwd( + att_m_tensor, + v, + o, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + ) + return o + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_token_attention_nopad(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + page_size = 4 + q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + kv = torch.randn((Z * N_CTX // page_size, page_size, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_page_indexs = ( + torch.randperm(max_input_len // page_size, dtype=torch.int32).cuda().view(Z, N_CTX // page_size) + ) + req_to_token_indexs = ( + req_to_page_indexs.unsqueeze(-1) * page_size + torch.arange(page_size, dtype=torch.int32, device="cuda") + ).reshape(Z, N_CTX) + + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX + b_start_loc = torch.arange(Z).cuda().int() * N_CTX + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + + o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_start_loc = b_start_loc + 
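+    # Reference output from the non-paged Triton token-attention kernels;
+    # the paged flashinfer decode result (o1) below is checked against it.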
+ ref_token_attention_nopad( + q, + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :], + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :], + o, + Q_HEADS, + HEAD_DIM, + infer_state, + req_to_token_indexs, + ) + # gqa_decode_attention_fwd( + # q, + # kv[:,:KV_HEADS,:], + # kv[:,KV_HEADS:,:], + # o, + # req_to_token_indexs, + # infer_state.b_req_idx, + # infer_state.b_seq_len, + # ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) + + num_pages_per_seq = torch.ceil(b_seq_len.float() / page_size).int() + kv_indptr = torch.zeros(Z + 1, dtype=torch.int32, device="cuda") + kv_indptr[1:] = torch.cumsum(num_pages_per_seq, dim=0) + + # Fill the paged KV data indices + total_pages = kv_indptr[-1].item() + kv_indices = torch.zeros(total_pages, dtype=torch.int32, device="cuda") + b_start_loc = num_pages_per_seq.cumsum(0) - num_pages_per_seq + for req, sl, start in zip(b_req_idx, num_pages_per_seq, b_start_loc): + kv_indices[start : start + sl] = req_to_page_indexs[req][:sl] + + # Calculate last page lengths + kv_last_page_len = torch.zeros(Z, dtype=torch.int32, device="cuda") + for i in range(Z): + seq_len = b_seq_len[i].item() + remainder = seq_len % page_size + kv_last_page_len[i] = remainder if remainder > 0 else page_size + + kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) + wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=True, + paged_kv_indptr_buffer=kv_indptr, + paged_kv_indices_buffer=kv_indices, + paged_kv_last_page_len_buffer=kv_last_page_len_buffer, + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_len, + q_heads, + kv_heads, + head_dim, + page_size, + q_data_type=dtype, + non_blocking=True, + ) + wrapper.run(q, (kv[:, :, :KV_HEADS, :], kv[:, :, KV_HEADS:, :]), out=o1, return_lse=False) + cos_sim = F.cosine_similarity(o, o1).mean() + assert cos_sim == 1.0 + + +if __name__ == "__main__": + test_token_attention_nopad(32, 16384, 32, 4, 128)
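
Editorial note: the paged-KV bookkeeping introduced in this series (in flashinfer_struct.py and the paged unit tests above) repeatedly derives three quantities from the per-request sequence lengths: the page count ceil(seq_len / page_size), the exclusive prefix sum kv_indptr over those counts, and kv_last_page_len, the number of valid tokens in each request's final page. The standalone sketch below shows that arithmetic in isolation. It is illustrative only and not part of the patch; the helper name paged_kv_metadata is invented here, but the formulas mirror the ones used in the code above.

import torch


def paged_kv_metadata(b_seq_len: torch.Tensor, page_size: int):
    """Illustrative sketch: derive flashinfer-style paged-KV metadata
    (page counts, kv_indptr, kv_last_page_len) from per-request sequence lengths."""
    # number of pages each request occupies: ceil(seq_len / page_size)
    num_pages_per_seq = (b_seq_len + page_size - 1) // page_size
    # indptr: exclusive prefix sum over page counts, shape [batch + 1]
    kv_indptr = torch.zeros(b_seq_len.numel() + 1, dtype=torch.int32, device=b_seq_len.device)
    kv_indptr[1:] = torch.cumsum(num_pages_per_seq, dim=0)
    # tokens actually used in each request's last page (page_size when the page is full)
    kv_last_page_len = b_seq_len - (num_pages_per_seq - 1) * page_size
    return num_pages_per_seq, kv_indptr, kv_last_page_len


# example: page_size=4, sequence lengths [5, 8] -> 2 pages each,
# kv_indptr = [0, 2, 4], kv_last_page_len = [1, 4]
if __name__ == "__main__":
    b_seq_len = torch.tensor([5, 8], dtype=torch.int32)
    print(paged_kv_metadata(b_seq_len, page_size=4))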