Commit b93b511

Author: niushengxiao (committed)

feat: add stop string matching

1 parent fc72ffa · commit b93b511

5 files changed: +91 −17 lines

lightllm/server/api_openai.py

Lines changed: 20 additions & 3 deletions
@@ -428,7 +428,9 @@ async def process_single_prompt(prompt: Union[str, List[int]], prompt_index: int
             prompt, individual_sampling_params, multimodal_params, request=raw_request
         )
 
-        return await _collect_generation_results(generator, request, prompt_str, prompt_index)
+        return await _collect_generation_results(
+            generator, request, prompt_str, prompt_index, individual_sampling_params
+        )
 
     tasks = [asyncio.create_task(process_single_prompt(prompt, i)) for i, prompt in enumerate(prompts)]
 
@@ -487,7 +489,9 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
         return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)
 
 
-async def _collect_generation_results(generator, request: CompletionRequest, prompt: str, prompt_index: int):
+async def _collect_generation_results(
+    generator, request: CompletionRequest, prompt: str, prompt_index: int, sampling_params: SamplingParams
+):
     final_output = []
     count_output_tokens = 0
     finish_reason = None
@@ -518,9 +522,22 @@ async def _collect_generation_results(generator, request: CompletionRequest, pro
             finish_reason = finish_status.get_finish_reason()
             prompt_tokens = metadata["prompt_tokens"]
 
+    # Strip a matched stop sequence from the end of the output
+    final_text = "".join(final_output)
+    if finish_reason == "stop" and hasattr(sampling_params, "stop_sequences") and sampling_params.stop_sequences:
+        stop_strings = sampling_params.stop_sequences.to_string()
+        valid_stop_strings = [s for s in stop_strings if s]
+        if valid_stop_strings:
+            text_len = len(final_text)
+            for stop_str in valid_stop_strings:
+                stop_len = len(stop_str)
+                if text_len >= stop_len and final_text.endswith(stop_str):
+                    final_text = final_text[:-stop_len]
+                    break
+
     return {
         "index": prompt_index,
-        "text": "".join(final_output),
+        "text": final_text,
         "finish_reason": finish_reason,
         "prompt_tokens": prompt_tokens,
         "completion_tokens": count_output_tokens,

lightllm/server/core/objs/sampling_params.py

Lines changed: 19 additions & 2 deletions
@@ -10,6 +10,7 @@
 
 # Maximum length limits, read from environment variables
 STOP_SEQUENCE_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_MAX_LENGTH", 256))
+STOP_SEQUENCE_STR_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_STR_MAX_LENGTH", 256))
 ALLOWED_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_ALLOWED_TOKEN_IDS_MAX_LENGTH", 256))
 MAX_STOP_SEQUENCES = int(os.getenv("LIGHTLLM_MAX_STOP_SEQUENCES", 10))
 REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
@@ -22,17 +23,27 @@ class StopSequence(ctypes.Structure):
     _fields_ = [
         ("sequence", ctypes.c_int * STOP_SEQUENCE_MAX_LENGTH),
         ("size", ctypes.c_int),
+        ("sequence_str", ctypes.c_char * STOP_SEQUENCE_STR_MAX_LENGTH),
+        ("sequence_str_len", ctypes.c_int),
     ]
 
-    def initialize(self, sequence: List[int]):
+    def initialize(self, sequence: List[int], sequence_str: str = ""):
         self.size = len(sequence)
         assert self.size <= STOP_SEQUENCE_MAX_LENGTH, "stop token length too long."
         assert all(isinstance(e, int) for e in sequence), "all must be int"
         self.sequence[: self.size] = sequence[:]
 
+        sequence_str_bytes = sequence_str.encode("utf-8")
+        assert len(sequence_str_bytes) < STOP_SEQUENCE_STR_MAX_LENGTH, "stop sequence string too long."
+        self.sequence_str = sequence_str_bytes
+        self.sequence_str_len = len(sequence_str_bytes)
+
     def to_list(self):
        return list(self.sequence[0 : self.size])
 
+    def to_string(self):
+        return bytes(self.sequence_str[0 : self.sequence_str_len]).decode("utf-8")
+
 
 class StopSequenceGroups(ctypes.Structure):
     _pack_ = 4
@@ -45,8 +56,10 @@ def initialize(self, stop_sequences: Union[str, List], tokenizer):
         groups: List[List[int]] = self.stop_sentences_to_token_ids(stop_sequences, tokenizer)
         self.size = len(groups)
         assert self.size <= MAX_STOP_SEQUENCES, "Too many stop sequence groups."
+        if isinstance(stop_sequences, str):
+            stop_sequences = [stop_sequences]
         for group_idx in range(self.size):
-            self.groups[group_idx].initialize(groups[group_idx])
+            self.groups[group_idx].initialize(groups[group_idx], stop_sequences[group_idx])
 
     def stop_sentences_to_token_ids(self, stop_sequences, tokenizer):
         if stop_sequences is None:
@@ -75,6 +88,10 @@ def _stop_str_to_token_ids(self, stop_str: str, tokenizer):
     def to_list(self):
         return [self.groups[i].to_list() for i in range(self.size)]
 
+    def to_string(self):
+        # Match in descending length order: when both "\n\n" and "\n" are present, "\n\n" is matched first
+        return sorted([self.groups[i].to_string() for i in range(self.size)], key=len, reverse=True)
+
 
 class RegularConstraint(ctypes.Structure):
     _pack_ = 4
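
Note: sequence_str is a fixed-size ctypes byte buffer, so the byte length is stored alongside it; to_string() slices by that length before decoding so the NUL padding never leaks into the result. A small sketch of the round trip (StopSequenceDemo is a cut-down stand-in, not the real structure):

import ctypes

STOP_SEQUENCE_STR_MAX_LENGTH = 256

class StopSequenceDemo(ctypes.Structure):
    _fields_ = [
        ("sequence_str", ctypes.c_char * STOP_SEQUENCE_STR_MAX_LENGTH),
        ("sequence_str_len", ctypes.c_int),
    ]

s = StopSequenceDemo()
raw = "。\n".encode("utf-8")  # multi-byte UTF-8, so character count != byte count
s.sequence_str = raw
s.sequence_str_len = len(raw)
# Slicing by the stored byte length strips the trailing NUL padding of the fixed buffer
assert bytes(s.sequence_str[0 : s.sequence_str_len]).decode("utf-8") == "。\n"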

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 25 additions & 2 deletions
@@ -277,6 +277,7 @@ def init_all(self):
         g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self)
 
         self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list()
+        self.stop_sequences_str = self.sampling_param.shm_param.stop_sequences.to_string()
         # management object that is only used in token healing mode
         if self.shm_req.prefix_token_ids.size != 0:
             self.prefix_token_ids = self.shm_req.prefix_token_ids.get_token_ids()
@@ -344,8 +345,8 @@ def update_mtp_accepted_token_num(self, accept_token_num: int):
     def get_last_gen_token(self):
         return self.shm_req.shm_prompt_ids.arr[self.shm_req.input_len + self.cur_output_len - 1]
 
-    def update_finish_status(self, eos_ids):
-        if self._stop_sequences_matched():
+    def update_finish_status(self, eos_ids, tokenizer=None):
+        if self._stop_sequences_matched() or self._stop_sequences_str_matched(tokenizer):
             self.finish_status.set_status(FinishStatus.FINISHED_STOP)
         elif (
             self.cur_output_len > 0
@@ -373,6 +374,28 @@ def _stop_sequences_matched(self):
                 return True
         return False
 
+    def _stop_sequences_str_matched(self, tokenizer):
+        if not self.stop_sequences_str or tokenizer is None:
+            return False
+
+        max_stop_str_len = max(len(stop_str) for stop_str in self.stop_sequences_str) if self.stop_sequences_str else 0
+        if max_stop_str_len == 0:
+            return False
+
+        tail_token_len = min(self.cur_output_len, max_stop_str_len + 10)  # +10 for safety
+        if tail_token_len > 0:
+            tail_token_ids = self.shm_req.shm_prompt_ids.arr[
+                (self.shm_req.input_len + self.cur_output_len - tail_token_len) : (
+                    self.shm_req.input_len + self.cur_output_len
+                )
+            ]
+            tail_str = tokenizer.decode(tail_token_ids, skip_special_tokens=False)
+            for stop_str in self.stop_sequences_str:
+                if stop_str in tail_str:
+                    logger.info(f"Found stop sequence in tail: stop_str='{stop_str}', tail_str='{tail_str}'")
+                    return True
+        return False
+
 
 class InferReqGroup:
     def __init__(
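
Note: _stop_sequences_str_matched decodes a window of trailing tokens instead of comparing token ids because a stop string can straddle token boundaries. A rough illustration of that failure mode, assuming a Hugging Face tokenizer is available locally (the model name and stop string are arbitrary examples, not values used by lightllm):

from transformers import AutoTokenizer  # assumption: transformers is installed

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer serves for the illustration
stop_str = "Observation:"
token_ids = tokenizer.encode("...\nObservation:")  # the stop string may be split across several tokens

# No single token corresponds to the stop string, so id-level matching alone can miss it...
print(tokenizer.convert_ids_to_tokens(token_ids))

# ...but decoding the tail window recovers the text, and a substring check finds the match.
tail_str = tokenizer.decode(token_ids, skip_special_tokens=False)
print(stop_str in tail_str)  # True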

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 10 additions & 1 deletion
@@ -14,8 +14,10 @@
 from lightllm.common.basemodel.batch_objs import ModelOutput
 from lightllm.utils.dist_utils import init_distributed_env
 from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.utils.envs_utils import enable_stop_string_match
 from lightllm.server.core.objs import ShmReqManager, StartArgs
 from lightllm.server.router.model_infer.infer_batch import g_infer_context
+from lightllm.server.tokenizer import get_tokenizer
 from lightllm.utils.dist_utils import get_global_rank, get_global_world_size, get_dp_size
 from lightllm.utils.dist_utils import get_dp_world_size, get_global_dp_rank, get_current_rank_in_dp
 from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size
@@ -303,7 +305,14 @@ def _post_handle(
                 continue
 
             # update and check the request's finished status
-            req_obj.update_finish_status(self.eos_id)
+            if enable_stop_string_match():
+                if not hasattr(self, "tokenizer"):
+                    self.tokenizer = get_tokenizer(
+                        self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code
+                    )
+                req_obj.update_finish_status(self.eos_id, self.tokenizer)
+            else:
+                req_obj.update_finish_status(self.eos_id)
 
             if extra_post_req_handle_func is not None:
                 extra_post_req_handle_func(req_obj, next_token_id, next_token_logprob)
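
Note: building the tokenizer lazily inside _post_handle keeps the cost off backends that never enable the feature, and the new tokenizer argument defaults to None so existing callers are unaffected. A stripped-down sketch of the cached-attribute pattern (BackendSketch and _load_tokenizer are illustrative names, not the backend's real interface):

class BackendSketch:
    def _get_tokenizer(self):
        # Constructed on first use and cached on the instance, so later requests pay nothing extra
        if not hasattr(self, "tokenizer"):
            self.tokenizer = self._load_tokenizer()
        return self.tokenizer

    def _load_tokenizer(self):
        return object()  # stands in for get_tokenizer(model_dir, tokenizer_mode, ...)

backend = BackendSketch()
assert backend._get_tokenizer() is backend._get_tokenizer()  # same cached object both times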

lightllm/utils/envs_utils.py

Lines changed: 17 additions & 9 deletions
@@ -68,6 +68,11 @@ def get_lightllm_gunicorn_keep_alive():
     return int(os.getenv("LIGHTLMM_GUNICORN_KEEP_ALIVE", 10))
 
 
+@lru_cache(maxsize=None)
+def enable_stop_string_match():
+    return os.getenv("ENABLE_STOP_STRING_MATCH", "False").upper() in ["ON", "TRUE", "1"]
+
+
 @lru_cache(maxsize=None)
 def get_lightllm_websocket_max_message_size():
     """
@@ -77,15 +82,18 @@ def get_lightllm_websocket_max_message_size():
     return int(os.getenv("LIGHTLLM_WEBSOCKET_MAX_SIZE", 16 * 1024 * 1024))
 
 
-# get_redundancy_expert_ids and get_redundancy_expert_num are primarily used to obtain the IDs and number of redundant experts during inference.
-# They depend on a configuration file specified by ep_redundancy_expert_config_path, which is a JSON formatted text file.
-# The content format is as follows:
-# {
-#     "redundancy_expert_num": 1, # Number of redundant experts per rank
-#     "0": [0], # Key: layer_index (string), Value: list of original expert IDs that are redundant for this layer
-#     "1": [0],
-#     "default": [0] # Default list of redundant expert IDs if layer-specific entry is not found
-# }
+# get_redundancy_expert_ids and get_redundancy_expert_num are primarily
+# used to obtain the IDs and number of redundant experts during inference.
+# They depend on a configuration file specified by ep_redundancy_expert_config_path,
+# which is a JSON formatted text file.
+# The content format is as follows:
+# {
+#     "redundancy_expert_num": 1, # Number of redundant experts per rank
+#     "0": [0], # Key: layer_index (string),
+#               # Value: list of original expert IDs that are redundant for this layer
+#     "1": [0],
+#     "default": [0] # Default list of redundant expert IDs if layer-specific entry is not found
+# }
 
 
 @lru_cache(maxsize=None)
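
Note: because enable_stop_string_match() is wrapped in lru_cache, the environment variable is read once per process and the result is reused, so ENABLE_STOP_STRING_MATCH has to be set before the server starts. A quick sketch of that caching behaviour (reuses the function body from the diff above purely for demonstration):

import os
from functools import lru_cache

@lru_cache(maxsize=None)
def enable_stop_string_match():
    return os.getenv("ENABLE_STOP_STRING_MATCH", "False").upper() in ["ON", "TRUE", "1"]

os.environ["ENABLE_STOP_STRING_MATCH"] = "true"
print(enable_stop_string_match())  # True

os.environ["ENABLE_STOP_STRING_MATCH"] = "0"
print(enable_stop_string_match())  # still True: the first result is cached for the process lifetime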
