feat: add stop string matching #969

Open · wants to merge 1 commit into main
31 changes: 28 additions & 3 deletions lightllm/server/api_openai.py
@@ -428,7 +428,9 @@ async def process_single_prompt(prompt: Union[str, List[int]], prompt_index: int
prompt, individual_sampling_params, multimodal_params, request=raw_request
)

-return await _collect_generation_results(generator, request, prompt_str, prompt_index)
+return await _collect_generation_results(
+    generator, request, prompt_str, prompt_index, individual_sampling_params
+)

tasks = [asyncio.create_task(process_single_prompt(prompt, i)) for i, prompt in enumerate(prompts)]

@@ -487,7 +489,9 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)


-async def _collect_generation_results(generator, request: CompletionRequest, prompt: str, prompt_index: int):
+async def _collect_generation_results(
+    generator, request: CompletionRequest, prompt: str, prompt_index: int, sampling_params: SamplingParams
+):
final_output = []
count_output_tokens = 0
finish_reason = None
@@ -518,9 +522,30 @@ async def _collect_generation_results(generator, request: CompletionRequest, pro
finish_reason = finish_status.get_finish_reason()
prompt_tokens = metadata["prompt_tokens"]

# Strip matched stop sequences from the end of the final text
final_text = "".join(final_output)
if finish_reason == "stop" and sampling_params.stop_sequences.size > 0:
stop_strings = sampling_params.stop_sequences.to_string()
valid_stop_strings = [s for s in stop_strings if s]
if valid_stop_strings:
max_stop_len = len(valid_stop_strings[0])
search_len = min(len(final_text), max_stop_len + 20)  # search window: longest stop string plus a 20-char margin
tail_text = final_text[-search_len:] if search_len > 0 else final_text
tail_start_pos = len(final_text) - search_len
earliest_stop_index = len(final_text)
for stop_str in valid_stop_strings:
stop_index = tail_text.find(stop_str)
if stop_index != -1:
actual_stop_index = tail_start_pos + stop_index
if actual_stop_index < earliest_stop_index:
earliest_stop_index = actual_stop_index

if earliest_stop_index < len(final_text):
final_text = final_text[:earliest_stop_index]

return {
"index": prompt_index,
"text": "".join(final_output),
"text": final_text,
"finish_reason": finish_reason,
"prompt_tokens": prompt_tokens,
"completion_tokens": count_output_tokens,
21 changes: 19 additions & 2 deletions lightllm/server/core/objs/sampling_params.py
@@ -10,6 +10,7 @@

# Read maximum length limits from environment variables
STOP_SEQUENCE_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_MAX_LENGTH", 256))
STOP_SEQUENCE_STR_MAX_LENGTH = int(os.getenv("LIGHTLLM_STOP_SEQUENCE_STR_MAX_LENGTH", 256))
ALLOWED_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_ALLOWED_TOKEN_IDS_MAX_LENGTH", 256))
MAX_STOP_SEQUENCES = int(os.getenv("LIGHTLLM_MAX_STOP_SEQUENCES", 10))
REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
@@ -22,17 +23,27 @@ class StopSequence(ctypes.Structure):
_fields_ = [
("sequence", ctypes.c_int * STOP_SEQUENCE_MAX_LENGTH),
("size", ctypes.c_int),
("sequence_str", ctypes.c_char * STOP_SEQUENCE_STR_MAX_LENGTH),
("sequence_str_len", ctypes.c_int),
]

-def initialize(self, sequence: List[int]):
+def initialize(self, sequence: List[int], sequence_str: str = ""):
self.size = len(sequence)
assert self.size <= STOP_SEQUENCE_MAX_LENGTH, "stop token length too long."
assert all(isinstance(e, int) for e in sequence), "all must be int"
self.sequence[: self.size] = sequence[:]

sequence_str_bytes = sequence_str.encode("utf-8")
assert len(sequence_str_bytes) < STOP_SEQUENCE_STR_MAX_LENGTH, "stop sequence string too long."
self.sequence_str = sequence_str_bytes
self.sequence_str_len = len(sequence_str_bytes)

def to_list(self):
return list(self.sequence[0 : self.size])

def to_string(self):
return bytes(self.sequence_str[0 : self.sequence_str_len]).decode("utf-8")
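
As a sanity check on the fixed-size buffer layout, here is a small round-trip sketch using a standalone structure assumed to mirror the two new fields (the demo class and constant are illustrative):

import ctypes

STR_MAX = 256  # mirrors STOP_SEQUENCE_STR_MAX_LENGTH

class Demo(ctypes.Structure):
    _fields_ = [("sequence_str", ctypes.c_char * STR_MAX), ("sequence_str_len", ctypes.c_int)]

d = Demo()
raw = "停止\n".encode("utf-8")  # lengths are in bytes, so multi-byte UTF-8 round-trips
assert len(raw) < STR_MAX       # strict '<', as in the diff, leaves room for a trailing NUL
d.sequence_str = raw            # ctypes zero-fills the rest of the buffer
d.sequence_str_len = len(raw)
assert d.sequence_str[:d.sequence_str_len].decode("utf-8") == "停止\n"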


class StopSequenceGroups(ctypes.Structure):
_pack_ = 4
@@ -45,8 +56,10 @@ def initialize(self, stop_sequences: Union[str, List], tokenizer):
groups: List[List[int]] = self.stop_sentences_to_token_ids(stop_sequences, tokenizer)
self.size = len(groups)
assert self.size <= MAX_STOP_SEQUENCES, "Too many stop sequence groups."
if isinstance(stop_sequences, str):
stop_sequences = [stop_sequences]
for group_idx in range(self.size):
-self.groups[group_idx].initialize(groups[group_idx])
+self.groups[group_idx].initialize(groups[group_idx], stop_sequences[group_idx])
Comment on lines +59 to +62 (severity: high)

There's a potential issue in this loop. If stop_sequences contains a list of token IDs (e.g., [1, 2, 3]), it will be used as the sequence_str argument to StopSequence.initialize, which expects a string. This can lead to incorrect behavior or a crash. Ensure that only strings are passed as sequence_str.
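
A small sketch of the guard this comment suggests (the helper name is hypothetical, not part of the PR):

def normalize_stop_entries(stop_sequences):
    # Wrap a bare string, then keep only string entries for sequence_str;
    # token-id lists fall back to "" rather than crashing initialize().
    if isinstance(stop_sequences, str):
        stop_sequences = [stop_sequences]
    return [s if isinstance(s, str) else "" for s in stop_sequences]

assert normalize_stop_entries("User:") == ["User:"]
assert normalize_stop_entries(["User:", [1, 2, 3]]) == ["User:", ""]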


def stop_sentences_to_token_ids(self, stop_sequences, tokenizer):
if stop_sequences is None:
@@ -75,6 +88,10 @@ def _stop_str_to_token_ids(self, stop_str: str, tokenizer):
def to_list(self):
return [self.groups[i].to_list() for i in range(self.size)]

def to_string(self):
# Match in descending length order: when both "\n\n" and "\n" are stop strings, "\n\n" is matched first
return sorted([self.groups[i].to_string() for i in range(self.size)], key=len, reverse=True)
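
For example, the descending-length sort ensures a longer stop string wins when a shorter one is its prefix:

stops = ["\n", "User:", "\n\n"]
assert sorted(stops, key=len, reverse=True) == ["User:", "\n\n", "\n"]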


class RegularConstraint(ctypes.Structure):
_pack_ = 4
27 changes: 25 additions & 2 deletions lightllm/server/router/model_infer/infer_batch.py
@@ -277,6 +277,7 @@ def init_all(self):
g_infer_context.req_manager.req_sampling_params_manager.init_req_sampling_params(self)

self.stop_sequences = self.sampling_param.shm_param.stop_sequences.to_list()
self.stop_sequences_str = self.sampling_param.shm_param.stop_sequences.to_string()
# management object used only in token healing mode
if self.shm_req.prefix_token_ids.size != 0:
self.prefix_token_ids = self.shm_req.prefix_token_ids.get_token_ids()
@@ -344,8 +345,8 @@ def update_mtp_accepted_token_num(self, accept_token_num: int):
def get_last_gen_token(self):
return self.shm_req.shm_prompt_ids.arr[self.shm_req.input_len + self.cur_output_len - 1]

-def update_finish_status(self, eos_ids):
-    if self._stop_sequences_matched():
+def update_finish_status(self, eos_ids, tokenizer=None):
+    if self._stop_sequences_matched() or self._stop_sequences_str_matched(tokenizer):
self.finish_status.set_status(FinishStatus.FINISHED_STOP)
elif (
self.cur_output_len > 0
@@ -373,6 +374,28 @@ def _stop_sequences_matched(self):
return True
return False

def _stop_sequences_str_matched(self, tokenizer):
if not self.stop_sequences_str or tokenizer is None:
return False

max_stop_str_len = max(len(stop_str) for stop_str in self.stop_sequences_str) if self.stop_sequences_str else 0
if max_stop_str_len == 0:
return False

tail_token_len = min(self.cur_output_len, max_stop_str_len + 10) # +10 for safety
if tail_token_len > 0:
tail_token_ids = self.shm_req.shm_prompt_ids.arr[
(self.shm_req.input_len + self.cur_output_len - tail_token_len) : (
self.shm_req.input_len + self.cur_output_len
)
]
tail_str = tokenizer.decode(tail_token_ids, skip_special_tokens=False)
for stop_str in self.stop_sequences_str:
if stop_str in tail_str:

Comment (severity: high)

Using `in` to check for the stop string can lead to premature termination. If the stop string is "User:" and the model generates "Assistant: Hello! User: Hi!", the generation would stop prematurely. Use tail_str.endswith(stop_str) to match only at the end.

Suggested change:
-if stop_str in tail_str:
+if tail_str.endswith(stop_str):

logger.info(f"Found stop sequence in tail: stop_str='{stop_str}', tail_str='{tail_str}'")
return True
return False
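
A minimal sketch of the tail-window check with the reviewer's endswith fix applied; the function name, decode stub, and margin are illustrative, not the PR's API:

def tail_stop_match(output_ids, stop_strings, decode_fn, safety_margin=10):
    if not stop_strings:
        return False
    max_len = max(len(s) for s in stop_strings)
    if max_len == 0:
        return False
    # Decode only a bounded tail; the margin absorbs tokens that decode
    # to more than one character.
    tail = decode_fn(output_ids[-(max_len + safety_margin):])
    return any(tail.endswith(s) for s in stop_strings)

# Toy "tokenizer" where each token id is a Unicode code point
assert tail_stop_match([72, 105, 33, 10, 10], ["\n\n"], lambda ids: "".join(map(chr, ids)))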


class InferReqGroup:
def __init__(
11 changes: 10 additions & 1 deletion lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -14,8 +14,10 @@
from lightllm.common.basemodel.batch_objs import ModelOutput
from lightllm.utils.dist_utils import init_distributed_env
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.utils.envs_utils import enable_stop_string_match
from lightllm.server.core.objs import ShmReqManager, StartArgs
from lightllm.server.router.model_infer.infer_batch import g_infer_context
from lightllm.server.tokenizer import get_tokenizer
from lightllm.utils.dist_utils import get_global_rank, get_global_world_size, get_dp_size
from lightllm.utils.dist_utils import get_dp_world_size, get_global_dp_rank, get_current_rank_in_dp
from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size
@@ -303,7 +305,14 @@ def _post_handle(
continue

# Update the request's finished status
-req_obj.update_finish_status(self.eos_id)
+if enable_stop_string_match():
+    if not hasattr(self, "tokenizer"):
+        self.tokenizer = get_tokenizer(
+            self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code
+        )
+    req_obj.update_finish_status(self.eos_id, self.tokenizer)
+else:
+    req_obj.update_finish_status(self.eos_id)

if extra_post_req_handle_func is not None:
extra_post_req_handle_func(req_obj, next_token_id, next_token_logprob)
5 changes: 5 additions & 0 deletions lightllm/utils/envs_utils.py
@@ -68,6 +68,11 @@ def get_lightllm_gunicorn_keep_alive():
return int(os.getenv("LIGHTLMM_GUNICORN_KEEP_ALIVE", 10))


@lru_cache(maxsize=None)
def enable_stop_string_match():
return os.getenv("ENABLE_STOP_STRING_MATCH", "False").upper() in ["ON", "TRUE", "1"]
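
A usage sketch; note that because of lru_cache the variable must be set before the first call:

import os
os.environ["ENABLE_STOP_STRING_MATCH"] = "1"  # "on", "true", or "1" (case-insensitive)

from lightllm.utils.envs_utils import enable_stop_string_match
assert enable_stop_string_match() is True     # result is cached from here on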


@lru_cache(maxsize=None)
def get_lightllm_websocket_max_message_size():
"""