diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index b0fcbbea2..942411037 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -426,7 +426,9 @@ async def process_single_prompt(prompt: Union[str, List[int]], prompt_index: int
             prompt, individual_sampling_params, multimodal_params, request=raw_request
         )
-        return await _collect_generation_results(generator, request, prompt_str, prompt_index)
+        return await _collect_generation_results(
+            generator, request, prompt_str, prompt_index, individual_sampling_params
+        )
 
     tasks = [asyncio.create_task(process_single_prompt(prompt, i)) for i, prompt in enumerate(prompts)]
 
@@ -485,7 +487,9 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
     return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)
 
 
-async def _collect_generation_results(generator, request: CompletionRequest, prompt: str, prompt_index: int):
+async def _collect_generation_results(
+    generator, request: CompletionRequest, prompt: str, prompt_index: int, sampling_params: SamplingParams
+):
     final_output = []
     count_output_tokens = 0
     finish_reason = None
@@ -516,9 +520,30 @@ async def _collect_generation_results(generator, request: CompletionRequest, pro
             finish_reason = finish_status.get_finish_reason()
             prompt_tokens = metadata["prompt_tokens"]
 
+    # Strip the matched stop sequence from the returned text
+    final_text = "".join(final_output)
+    if finish_reason == "stop" and sampling_params.stop_sequences.size > 0:
+        stop_strings = sampling_params.stop_sequences.to_string()
+        valid_stop_strings = [s for s in stop_strings if s]
+        if valid_stop_strings:
+            max_stop_len = len(valid_stop_strings[0])
+            search_len = min(len(final_text), max_stop_len + 20)  # search window: longest stop sequence length + 20
+            tail_text = final_text[-search_len:] if search_len > 0 else final_text
+            tail_start_pos = len(final_text) - search_len
+            earliest_stop_index = len(final_text)
+            for stop_str in valid_stop_strings:
+                stop_index = tail_text.find(stop_str)
+                if stop_index != -1:
+                    actual_stop_index = tail_start_pos + stop_index
+                    if actual_stop_index < earliest_stop_index:
+                        earliest_stop_index = actual_stop_index
+
+            if earliest_stop_index < len(final_text):
+                final_text = final_text[:earliest_stop_index]
+
     return {
         "index": prompt_index,
-        "text": "".join(final_output),
+        "text": final_text,
         "finish_reason": finish_reason,
         "prompt_tokens": prompt_tokens,
         "completion_tokens": count_output_tokens,
diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py
index 634dd208d..3e2271e90 100644
--- a/lightllm/server/core/objs/sampling_params.py
+++ b/lightllm/server/core/objs/sampling_params.py
@@ -10,6 +10,7 @@
 # 从环境变量获取最大长度限制
 STOP_SEQUENCE_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_MAX_LENGTH", 256))
+STOP_SEQUENCE_STR_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_STR_MAX_LENGTH", 256))
 ALLOWED_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_ALLOWED_TOKEN_IDS_MAX_LENGTH", 256))
 MAX_STOP_SEQUENCES = int(os.getenv("LIGHTLLM_MAX_STOP_SEQUENCES", 10))
 REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
@@ -22,17 +23,27 @@ class StopSequence(ctypes.Structure):
     _fields_ = [
         ("sequence", ctypes.c_int * STOP_SEQUENCE_MAX_LENGTH),
         ("size", ctypes.c_int),
+        ("sequence_str", ctypes.c_char * STOP_SEQUENCE_STR_MAX_LENGTH),
+        ("sequence_str_len", ctypes.c_int),
     ]
 
-    def initialize(self, sequence: List[int]):
+    def initialize(self, sequence: List[int], sequence_str: str = ""):
         self.size = len(sequence)
         assert self.size <= STOP_SEQUENCE_MAX_LENGTH, "stop token length too long."
         assert all(isinstance(e, int) for e in sequence), "all must be int"
         self.sequence[: self.size] = sequence[:]
 
+        sequence_str_bytes = sequence_str.encode("utf-8")
+        assert len(sequence_str_bytes) < STOP_SEQUENCE_STR_MAX_LENGTH, "stop sequence string too long."
+        self.sequence_str = sequence_str_bytes
+        self.sequence_str_len = len(sequence_str_bytes)
+
     def to_list(self):
         return list(self.sequence[0 : self.size])
 
+    def to_string(self):
+        return bytes(self.sequence_str[0 : self.sequence_str_len]).decode("utf-8")
+
 
 class StopSequenceGroups(ctypes.Structure):
     _pack_ = 4
@@ -45,8 +56,10 @@ def initialize(self, stop_sequences: Union[str, List], tokenizer):
         groups: List[List[int]] = self.stop_sentences_to_token_ids(stop_sequences, tokenizer)
         self.size = len(groups)
         assert self.size <= MAX_STOP_SEQUENCES, "Too many stop sequence groups."
+        if isinstance(stop_sequences, str):
+            stop_sequences = [stop_sequences]
         for group_idx in range(self.size):
-            self.groups[group_idx].initialize(groups[group_idx])
+            self.groups[group_idx].initialize(groups[group_idx], stop_sequences[group_idx])
 
     def stop_sentences_to_token_ids(self, stop_sequences, tokenizer):
         if stop_sequences is None:
@@ -75,6 +88,10 @@ def _stop_str_to_token_ids(self, stop_str: str, tokenizer):
     def to_list(self):
         return [self.groups[i].to_list() for i in range(self.size)]
 
+    def to_string(self):
+        # Sort by length in descending order so that when both "\n\n" and "\n" are present, "\n\n" is matched first
+        return sorted([self.groups[i].to_string() for i in range(self.size)], key=len, reverse=True)
+
 
 class RegularConstraint(ctypes.Structure):
     _pack_ = 4
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 0b930c72d..ef001aeb6 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -318,6 +318,7 @@ def _init_all_state(self):
         g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self)
 
         self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list()
+        self.stop_sequences_str = self.sampling_param.shm_param.stop_sequences.to_string()
         # token healing mode 才被使用的管理对象
         if self.shm_req.prefix_token_ids.size != 0:
             self.prefix_token_ids = self.shm_req.prefix_token_ids.get_token_ids()
@@ -378,8 +379,8 @@ def update_mtp_accepted_token_num(self, accept_token_num: int):
     def get_last_gen_token(self):
         return self.shm_req.shm_prompt_ids.arr[self.shm_req.input_len + self.cur_output_len - 1]
 
-    def update_finish_status(self, eos_ids, output_len: int):
-        if self._stop_sequences_matched(output_len=output_len):
+    def update_finish_status(self, eos_ids, output_len: int, tokenizer=None):
+        if self._stop_sequences_matched(output_len=output_len) or self._stop_sequences_str_matched(tokenizer):
             self.finish_status.set_status(FinishStatus.FINISHED_STOP)
         elif (
             output_len > 0
@@ -404,6 +405,28 @@ def _stop_sequences_matched(self, output_len: int):
                 return True
         return False
 
+    def _stop_sequences_str_matched(self, tokenizer):
+        if not self.stop_sequences_str or tokenizer is None:
+            return False
+
+        max_stop_str_len = max(len(stop_str) for stop_str in self.stop_sequences_str) if self.stop_sequences_str else 0
+        if max_stop_str_len == 0:
+            return False
+
+        tail_token_len = min(self.cur_output_len, max_stop_str_len + 10)  # +10 for safety
+        if tail_token_len > 0:
+            tail_token_ids = self.shm_req.shm_prompt_ids.arr[
+                (self.shm_req.input_len + self.cur_output_len - tail_token_len) : (
+                    self.shm_req.input_len + self.cur_output_len
+                )
+            ]
+            tail_str = tokenizer.decode(tail_token_ids, skip_special_tokens=False)
+            for stop_str in self.stop_sequences_str:
+                if stop_str in tail_str:
+                    logger.info(f"Found stop sequence in tail: stop_str='{stop_str}', tail_str='{tail_str}'")
+                    return True
+        return False
+
     def prefill_need_token_num(self, is_chuncked_prefill: bool):
         if is_chuncked_prefill:
             input_token_ids = self.get_chuncked_input_token_ids()
@@ -482,6 +505,7 @@ def handle(
         eos_ids: List[int],
         extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]],
         is_master_in_dp: bool,
+        tokenizer=None,
     ):
         if self.output_len <= 0:
            return
@@ -503,7 +527,7 @@ def handle(
             return
 
         # 更新判断请求的 finished 状态
-        req_obj.update_finish_status(eos_ids=eos_ids, output_len=self.output_len)
+        req_obj.update_finish_status(eos_ids=eos_ids, output_len=self.output_len, tokenizer=tokenizer)
 
         if extra_post_req_handle_func is not None:
             extra_post_req_handle_func(req_obj, next_token_id, next_token_logprob)
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index ce7bb53cf..70214d1c7 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -26,11 +26,12 @@
 from lightllm.utils.dist_utils import get_dp_world_size, get_global_dp_rank, get_current_rank_in_dp
 from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size
 from lightllm.utils.dist_utils import get_dp_rank_in_node, create_new_group_for_current_node
-from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.envs_utils import get_env_start_args, enable_stop_string_match
 from lightllm.distributed import dist_group_manager
 from lightllm.server.router.shm_reqs_io_buffer import ShmReqsIOBuffer
 from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack
 from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel
+from lightllm.server.tokenizer import get_tokenizer
 
 
 class ModeBackend:
@@ -504,6 +505,14 @@ def _post_handle(
         extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于
         约束输出等模式,设置自己请求内部的状态机的状态,并添加额外的停止判定条件等。
         """
+        if enable_stop_string_match():
+            if not hasattr(self, "tokenizer"):
+                self.tokenizer = get_tokenizer(
+                    self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code
+                )
+        else:
+            self.tokenizer = None
+
         for req_obj, next_token_id, next_token_logprob, pack in zip(
             run_reqs, next_token_ids, next_token_logprobs, run_reqs_update_packs
         ):
@@ -515,6 +524,7 @@
                 eos_ids=self.eos_id,
                 extra_post_req_handle_func=extra_post_req_handle_func,
                 is_master_in_dp=self.is_master_in_dp,
+                tokenizer=self.tokenizer,
             )
 
         g_infer_context.req_manager.req_sampling_params_manager.update_reqs_token_counter(
diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py
index b78784d82..2b8563e9c 100644
--- a/lightllm/utils/envs_utils.py
+++ b/lightllm/utils/envs_utils.py
@@ -68,6 +68,11 @@ def get_lightllm_gunicorn_keep_alive():
     return int(os.getenv("LIGHTLMM_GUNICORN_KEEP_ALIVE", 10))
 
 
+@lru_cache(maxsize=None)
+def enable_stop_string_match():
+    return os.getenv("ENABLE_STOP_STRING_MATCH", "False").upper() in ["ON", "TRUE", "1"]
+
+
 @lru_cache(maxsize=None)
 def get_lightllm_websocket_max_message_size():
     """