Fp8 deepseek #975

Open · wants to merge 1 commit into base: main

66 changes: 3 additions & 63 deletions lightllm/common/basemodel/basemodel.py
@@ -11,9 +11,7 @@
from lightllm.common.basemodel.infer_struct import InferStateInfo
from lightllm.common.mem_manager import MemoryManager
from lightllm.common.req_manager import ReqManager
from lightllm.common.infer_utils import init_req_to_token_indexes
from lightllm.common.build_utils import repair_config
from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
from lightllm.common.basemodel.cuda_graph import CudaGraph
from lightllm.common.quantization import Quantcfg
@@ -330,14 +328,6 @@ def _prefill(
model_input: ModelInput,
):
infer_state = self._create_inferstate(model_input)
init_req_to_token_indexes(
self.req_manager.req_to_token_indexs,
model_input.b_req_idx,
model_input.b_seq_len,
infer_state.b_ready_cache_len,
model_input.max_len_in_batch,
infer_state.mem_index,
)

infer_state.init_some_extra_state(self, model_input.input_ids)
return self._context_forward(model_input.input_ids, infer_state)
@@ -350,12 +340,6 @@ def _decode(
find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size)
padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size)
infer_state = self._create_inferstate(padded_model_input)
copy_kv_index_to_req(
self.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_seq_len,
infer_state.mem_index,
)
infer_state.init_some_extra_state(self, padded_model_input.input_ids)

if self.graph.need_capture(find_graph_batch_size):
@@ -371,12 +355,6 @@ def _decode(
)
else:
infer_state = self._create_inferstate(model_input)
copy_kv_index_to_req(
self.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_seq_len,
infer_state.mem_index,
)
infer_state.init_some_extra_state(self, model_input.input_ids)
model_output = self._token_forward(model_input.input_ids, infer_state)

@@ -458,25 +436,9 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids

infer_state0 = self._create_inferstate(model_input0, 0)
init_req_to_token_indexes(
self.req_manager.req_to_token_indexs,
model_input0.b_req_idx,
model_input0.b_seq_len,
infer_state0.b_ready_cache_len,
model_input0.max_len_in_batch,
infer_state0.mem_index,
)
infer_state0.init_some_extra_state(self, input_ids0)

infer_state1 = self._create_inferstate(model_input1, 1)
init_req_to_token_indexes(
self.req_manager.req_to_token_indexs,
model_input1.b_req_idx,
model_input1.b_seq_len,
infer_state1.b_ready_cache_len,
model_input1.max_len_in_batch,
infer_state1.mem_index,
)
infer_state1.init_some_extra_state(self, input_ids1)

model_output0, model_output1 = self._overlap_tpsp_context_forward(
Expand All @@ -502,20 +464,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size)
padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size)
infer_state0 = self._create_inferstate(padded_model_input0, 0)
copy_kv_index_to_req(
self.req_manager.req_to_token_indexs,
infer_state0.b_req_idx,
infer_state0.b_seq_len,
infer_state0.mem_index,
)
infer_state0.init_some_extra_state(self, padded_model_input0.input_ids)
infer_state1 = self._create_inferstate(padded_model_input1, 1)
copy_kv_index_to_req(
self.req_manager.req_to_token_indexs,
infer_state1.b_req_idx,
infer_state1.b_seq_len,
infer_state1.mem_index,
)
infer_state1.init_some_extra_state(self, padded_model_input1.input_ids)

if self.graph.need_capture(find_graph_batch_size):
@@ -540,20 +490,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
model_output1 = self._create_unpad_decode_model_output(model_output1, origin_batch_size=origin_batch_size)
else:
infer_state0 = self._create_inferstate(model_input0, 0)
copy_kv_index_to_req(
self.req_manager.req_to_token_indexs,
infer_state0.b_req_idx,
infer_state0.b_seq_len,
infer_state0.mem_index,
)
infer_state0.init_some_extra_state(self, model_input0.input_ids)
infer_state1 = self._create_inferstate(model_input1, 1)
copy_kv_index_to_req(
self.req_manager.req_to_token_indexs,
infer_state1.b_req_idx,
infer_state1.b_seq_len,
infer_state1.mem_index,
)
infer_state1.init_some_extra_state(self, model_input1.input_ids)

model_output0, model_output1 = self._overlap_tpsp_token_forward(
@@ -654,10 +592,12 @@ def _check_max_len_infer(self):
logger.info("begin check max_len infer")
dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
b_seq_len[:] = self.batch_max_tokens
b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
mem_indexes = self.mem_manager.alloc(
len(dummy_input_ids), b_req_idx, b_seq_len, b_ready_cache_len, True
).cuda()
total_token_num = self.batch_max_tokens
model_input = ModelInput(
batch_size=1,
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/cuda_graph.py
@@ -196,12 +196,12 @@ def warmup(self, model):
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
b_req_idx = torch.tensor(
[model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
)
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
b_seq_len.fill_(seq_len)
mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len, None, False).cuda()

model_input = ModelInput(
batch_size=batch_size,
@@ -250,12 +250,12 @@ def warmup_overlap(self, model):
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
b_req_idx = torch.tensor(
[model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
)
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
b_seq_len.fill_(seq_len)
mem_indexes = model.mem_manager.alloc(len(input_ids), b_req_idx, b_seq_len, None, False).cuda()

micro_batch = ModelInput(
is_prefill=False,
4 changes: 1 addition & 3 deletions lightllm/common/infer_utils.py
@@ -1,6 +1,4 @@
def init_req_to_token_indexes(
req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, max_len_in_batch, alloc_mem_index
):
def init_req_to_token_indexes(req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, alloc_mem_index):
start_index = 0
b_seq_len_numpy = b_seq_len.cpu().numpy()
b_ready_cache_len_numpy = b_ready_cache_len.cpu().numpy()
23 changes: 22 additions & 1 deletion lightllm/common/mem_manager.py
@@ -12,6 +12,8 @@
from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
from lightllm.distributed.pynccl import PyNcclCommunicator
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.common.infer_utils import init_req_to_token_indexes
from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req

logger = init_logger(__name__)

@@ -52,6 +54,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
layer_num,
)
self.HOLD_TOKEN_MEMINDEX = self.size
self.req_to_token_indexs = None

def get_cell_size(self):
return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)
@@ -243,7 +246,7 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to
def _free_buffers(self):
self.kv_buffer = None

def alloc(self, need_size) -> torch.Tensor:
def alloc(self, need_size, b_req_idx=None, b_seq_len=None, b_ready_cache_len=None, is_prefill=True) -> torch.Tensor:
if need_size > self.mark_end - self.mark_start:
logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
assert False, "error alloc state"
@@ -255,6 +258,24 @@ def alloc(self, need_size) -> torch.Tensor:

self.can_use_mem_size -= need_size
self.shared_can_use_token_num.set_value(self.can_use_mem_size)

if self.req_to_token_indexs is not None:
assert b_req_idx is not None and b_seq_len is not None, "b_req_idx and b_seq_len must be provided"
if is_prefill:

Review comment (high):

When is_prefill is True, b_ready_cache_len is passed to init_req_to_token_indexes. However, the alloc signature allows b_ready_cache_len to be None, which would cause a runtime error inside init_req_to_token_indexes, since that function calls .cpu().numpy() on the value.

Please add an assertion to ensure b_ready_cache_len is not None when is_prefill is True. For example:

if is_prefill:
    assert b_ready_cache_len is not None, "b_ready_cache_len must be provided for prefill"
    init_req_to_token_indexes(
        ...
    )
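
For illustration, here is a minimal, self-contained sketch of how that guard could sit next to the dispatch added in this PR. The helper name _bind_alloc_to_reqs is hypothetical; the two index-writing calls and their argument order are taken from the diff below, so treat this as a sketch rather than the PR's actual code.

from lightllm.common.infer_utils import init_req_to_token_indexes
from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req


def _bind_alloc_to_reqs(req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, mem_index, is_prefill):
    # Hypothetical helper mirroring the tail of the new MemoryManager.alloc.
    assert b_req_idx is not None and b_seq_len is not None, "b_req_idx and b_seq_len must be provided"
    if is_prefill:
        # Fail fast with a clear message instead of an AttributeError raised by
        # b_ready_cache_len.cpu() inside init_req_to_token_indexes.
        assert b_ready_cache_len is not None, "b_ready_cache_len must be provided for prefill"
        init_req_to_token_indexes(req_to_token_indexs, b_req_idx, b_seq_len, b_ready_cache_len, mem_index)
    else:
        # Decode appends one token per request, so no ready-cache length is needed.
        copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_seq_len, mem_index)

The prefill call sites updated in this PR (e.g. prepare_prefill_inputs) already pass b_ready_cache_len explicitly, so such an assertion would only trip on future callers that omit it.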

init_req_to_token_indexes(
self.req_to_token_indexs,
b_req_idx,
b_seq_len,
b_ready_cache_len,
ans,
)
else:
copy_kv_index_to_req(
self.req_to_token_indexs,
b_req_idx,
b_seq_len,
ans,
)
return ans

def free(self, free_index: Union[torch.Tensor, List[int]]):
1 change: 1 addition & 0 deletions lightllm/common/req_manager.py
@@ -62,6 +62,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
self.req_to_token_indexs = torch.zeros(
(max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda"
)
mem_manager.req_to_token_indexs = self.req_to_token_indexs
self.mem_manager = mem_manager
self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num)
self.max_request_num = max_request_num
@@ -69,7 +69,9 @@ def padded_prepare_prefill_inputs(
g_infer_state_lock.acquire()
if g_infer_context.radix_cache is not None:
g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num)
mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda()
mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, b_ready_cache_len, True
).cuda()
g_infer_state_lock.release()

if padded_req_num > 0:
@@ -155,7 +157,9 @@ def padded_prepare_decode_inputs(
g_infer_state_lock.acquire()
if g_infer_context.radix_cache is not None:
g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num)
mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda()
mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
input_ids.shape[0] - padded_req_num, b_req_idx, b_seq_len, None, False
).cuda()
g_infer_state_lock.release()

if padded_req_num > 0:
@@ -49,7 +49,9 @@ def prepare_prefill_inputs(
g_infer_state_lock.acquire()
if g_infer_context.radix_cache is not None:
g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0])
mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]).cuda()
mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
input_ids.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True
).cuda()
g_infer_state_lock.release()

model_input = ModelInput(
@@ -105,7 +107,9 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
g_infer_state_lock.acquire()
if g_infer_context.radix_cache is not None:
g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0])
mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]).cuda()
mem_indexes = g_infer_context.req_manager.mem_manager.alloc(
input_ids.shape[0], b_req_idx, b_seq_len, None, False
).cuda()
g_infer_state_lock.release()

model_input = ModelInput(
8 changes: 6 additions & 2 deletions test/benchmark/static_inference/model_infer.py
@@ -242,7 +242,9 @@ def run_forward_once(
b_seq_len[i] = input_len

total_token_num = batch_size * input_len
mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
mem_indexes = model_part.req_manager.mem_manager.alloc(
test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True
).cuda()

rank_id = model_kvargs["rank_id"]

@@ -303,7 +305,9 @@ def run_forward_once(
step_start = time.time()
total_token_num += batch_size
b_seq_len += 1
mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]).cuda()
mem_indexes = model_part.req_manager.mem_manager.alloc(
predict_ids.shape[0], b_req_idx, b_seq_len, None, False
).cuda()
max_len_in_batch = input_len + i + 1
logits = decode_fn(
model_part,
8 changes: 6 additions & 2 deletions test/benchmark/static_inference/model_infer_mtp.py
@@ -126,7 +126,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
b_seq_len[i] = input_len

total_token_num = input_len * batch_size
mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
mem_indexes = main_model.req_manager.mem_manager.alloc(
test_data.shape[0], b_req_idx, b_seq_len, b_ready_cache_len, True
).cuda()
# Main model Prefill
model_input = ModelInput(
batch_size=batch_size,
@@ -193,7 +195,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_

nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda")
nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda")
mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda()
mem_indexes = main_model.req_manager.mem_manager.alloc(
batch_size * (len(draft_models) + 1), nopad_b_seq_idx, nopad_b_seq_len, None, False
).cuda()

model_input = ModelInput(
batch_size=batch_size * (len(draft_models) + 1),