
feat: Support decode chunk PD serving mode #944

Open · wants to merge 19 commits into base: main

Changes from all commits

2 changes: 1 addition & 1 deletion .gitignore
@@ -5,4 +5,4 @@ dist
*.egg-info
.idea
.vscode
tmp/
9 changes: 9 additions & 0 deletions docs/CN/source/tutorial/api_server_args_zh.rst
@@ -64,6 +64,15 @@ PD disaggregation mode parameters

    Port number in configuration server mode


.. option:: --chunked_max_new_token

    Maximum number of new tokens per decode chunk; defaults to ``0``, which disables chunked decoding

.. option:: --pd_max_retry_count

    Maximum number of retries after a failed kv transfer in PD mode; defaults to ``3``

Model Configuration Parameters
------------------------------

8 changes: 8 additions & 0 deletions docs/EN/source/tutorial/api_server_args_zh.rst
@@ -64,6 +64,14 @@ PD disaggregation Mode Parameters

Port number in configuration server mode

.. option:: --chunked_max_new_token

    Maximum number of new tokens per decode chunk; defaults to ``0``, which disables chunked decoding

.. option:: --pd_max_retry_count

    Maximum number of retries after a failed kv transfer in PD mode; defaults to ``3``

Model Configuration Parameters
-----------------------------

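
A quick arithmetic note on how the two options interact (illustrative sketch; num_decode_chunks is not part of this PR):

import math

def num_decode_chunks(max_new_tokens: int, chunked_max_new_token: int) -> int:
    # how many decode rounds a request needs under the chunk budget
    if chunked_max_new_token <= 0:
        return 1  # 0 disables chunking: one full-length round
    return math.ceil(max_new_tokens / chunked_max_new_token)

assert num_decode_chunks(1000, 0) == 1    # chunking disabled
assert num_decode_chunks(1000, 256) == 4  # 256 + 256 + 256 + 232
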
12 changes: 12 additions & 0 deletions lightllm/server/api_cli.py
@@ -60,6 +60,18 @@ def make_argument_parser() -> argparse.ArgumentParser:
        default="default_model_name",
        help="just help to distinguish internal model name, use 'host:port/get_model_name' to get",
    )
    parser.add_argument(
        "--chunked_max_new_token",
        type=int,
        default=0,
        help="""Max new tokens generated per decode chunk in pd mode; 0 disables chunked decoding.""",
    )
    parser.add_argument(
        "--pd_max_retry_count",
        type=int,
        default=3,
        help="""Max number of retries for failed kv transfers in pd mode.""",
    )

    parser.add_argument(
        "--model_dir",
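
A smoke test for the new flags (a sketch, assuming none of the parser's other arguments are required):

from lightllm.server.api_cli import make_argument_parser

parser = make_argument_parser()
args = parser.parse_args(["--chunked_max_new_token", "256", "--pd_max_retry_count", "5"])
assert args.chunked_max_new_token == 256
assert args.pd_max_retry_count == 5
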
134 changes: 108 additions & 26 deletions lightllm/server/httpserver_for_pd_master/manager.py
@@ -24,7 +24,7 @@
from lightllm.server.metrics.manager import MetricClient
from lightllm.utils.statics_utils import MovingAverage
from lightllm.server.httpserver.manager import AsyncQueue
from lightllm.utils.error_utils import ServerBusyError, KVMoveTimeoutError

logger = init_logger(__name__)

@@ -123,6 +123,9 @@ async def generate(
    ):
        start_time = time.time()
        group_request_id = self.id_gen.generate_id()
        max_retries = self.args.pd_max_retry_count
        retry_count = 0

        try:
            sampling_params.group_request_id = group_request_id
            # log information about the arriving request
@@ -131,22 +134,41 @@
            self.metric_client.counter_inc("lightllm_request_count")
            self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens)

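            # Retry loop: each attempt re-selects a prefill/decode pair; a kv-move
            # timeout aborts the attempt and retries under a fresh group_request_id.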
            while retry_count <= max_retries:
                try:
                    p_node, d_node = await self.select_p_d_node(prompt, sampling_params, multimodal_params)

                    results_generator = self._wait_to_token_package(
                        p_node,
                        d_node,
                        start_time,
                        prompt,
                        sampling_params,
                        multimodal_params,
                        request,
                    )
                    async for sub_req_id, request_output, metadata, finish_status in results_generator:
                        yield sub_req_id, request_output, metadata, finish_status

                    break

                except KVMoveTimeoutError:
                    retry_count += 1
                    if retry_count <= max_retries:
                        logger.warning(
                            f"KV move timeout for group_request_id {group_request_id}, "
                            f"attempt {retry_count}/{max_retries}. Retrying with new nodes..."
                        )
                        # clean up the current request state before retrying
                        await self.abort(group_request_id)
                        # regenerate group_request_id to avoid collisions with the aborted attempt
                        group_request_id = self.id_gen.generate_id()
                        sampling_params.group_request_id = group_request_id
                        continue
                    else:
                        logger.error(
                            f"KV move timeout after {max_retries + 1} attempts for "
                            f"group_request_id {group_request_id}. Giving up."
                        )
                        raise ServerBusyError(f"KV move timeout after {max_retries + 1} attempts, server is busy now.")

        except BaseException as e:
            if not isinstance(e, KVMoveTimeoutError):
                logger.error(f"has exception {str(e)}")
            await self.abort(group_request_id)
            raise e

@@ -190,6 +212,7 @@ async def fetch_stream(
        self.req_id_to_out_inf[group_request_id] = req_status

        up_status_event = req_status.up_status_event
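        # reset the event before waiting: with chunked decoding it is reused for every kv transfer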
        up_status_event.clear()

        d_start_args = d_node.start_args
        decode_node_dict = {
@@ -234,22 +257,81 @@
            await asyncio.wait_for(up_status_event.wait(), timeout=60)
        except asyncio.TimeoutError:
            logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.")
            raise KVMoveTimeoutError(f"KV move timeout for group_request_id {group_request_id}")

        sampling_params.move_kv_to_decode_node.initialize(None)
        sampling_params.suggested_dp_index = up_status_event.upkv_status.dp_index

        remaining_tokens = old_max_new_tokens - 1
        chunked_max_new_token = self.args.chunked_max_new_token
        current_prompt_ids = list(prompt_ids)

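        # Chunked decode: request at most chunked_max_new_token tokens per round;
        # a value of 0 falls back to the original single-shot behaviour.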
        while remaining_tokens > 0:
            chunk_size = min(remaining_tokens, chunked_max_new_token) if chunked_max_new_token > 0 else remaining_tokens
            sampling_params.max_new_tokens = chunk_size

            await d_node.websocket.send_bytes(
                pickle.dumps((ObjType.REQ, (current_prompt_ids, sampling_params, multimodal_params)))
            )

            chunk_finished = False
            while not chunk_finished:
                await req_status.wait_to_ready()
                if await request.is_disconnected():
                    raise Exception(f"req_id {group_request_id} disconnected")

                if await req_status.can_read(self.req_id_to_out_inf):
                    token_list = await req_status.pop_all_tokens()
                    for sub_req_id, request_output, metadata, finish_status in token_list:
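                        # append each generated token id so the next chunk's request
                        # replays the full context (original prompt + tokens so far)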
                        current_prompt_ids.append(metadata.get("id"))
                        remaining_tokens -= 1

                        final_finish_status = finish_status

                        # reached max new tokens overall: truly finished
                        if remaining_tokens == 0:
                            final_finish_status = FinishStatus(FinishStatus.FINISHED_LENGTH)
                            chunk_finished = True
                        # hit a stop token: truly finished
                        elif finish_status == FinishStatus.FINISHED_STOP:
                            final_finish_status = FinishStatus(FinishStatus.FINISHED_STOP)
                            chunk_finished = True
                        # hit the chunk boundary: not actually finished
                        elif finish_status == FinishStatus.FINISHED_LENGTH:
                            final_finish_status = FinishStatus(FinishStatus.NO_FINISH)
                            chunk_finished = True

                        yield sub_req_id, request_output, metadata, final_finish_status

            if final_finish_status.is_finished():
                break

            # If this is not the last chunk, send the KV cache from the decode node back to the prefill node
            if remaining_tokens > 0:
                up_status_event = req_status.up_status_event
                up_status_event.clear()
                p_start_args = p_node.start_args
                prefill_node_dict = {
                    "node_id": p_start_args["pd_node_id"],
                    "ip": p_start_args["host"],
                    "rpyc_port": p_start_args["pd_decode_rpyc_port"],
                    "max_new_tokens": 0,
                    "pd_master_node_id": self.args.pd_node_id,
                }

                sampling_params.max_new_tokens = 0
                sampling_params.move_kv_to_decode_node.initialize(prefill_node_dict)
                sampling_params.suggested_dp_index = -1

                await p_node.websocket.send_bytes(
                    pickle.dumps((ObjType.REQ, (current_prompt_ids, sampling_params, multimodal_params)))
                )

                try:
                    await asyncio.wait_for(up_status_event.wait(), timeout=60)
                except asyncio.TimeoutError:
                    logger.warning(f"group_request_id: {group_request_id} kv move back time out err, server is busy now.")
                    raise KVMoveTimeoutError(f"KV move back timeout for group_request_id {group_request_id}")

        return

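
The chunk schedule that fetch_stream's loop walks through, restated as a standalone sketch (names are illustrative; the real loop also stops early on a stop token):

def chunk_schedule(total_new_tokens: int, chunked_max_new_token: int):
    # yields the per-round max_new_tokens sent to the decode node
    remaining = total_new_tokens
    while remaining > 0:
        chunk = min(remaining, chunked_max_new_token) if chunked_max_new_token > 0 else remaining
        yield chunk
        remaining -= chunk

assert list(chunk_schedule(10, 4)) == [4, 4, 2]  # one kv move-back between rounds
assert list(chunk_schedule(10, 0)) == [10]       # chunking disabled: single round

Every boundary between rounds costs a kv transfer back to the prefill node, so small chunk sizes trade transfer overhead for finer-grained scheduling.
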
5 changes: 5 additions & 0 deletions lightllm/utils/error_utils.py
@@ -16,3 +16,8 @@ def __init__(self, message="Server is busy, please try again later", status_code
    def __str__(self):
        """String representation of the error"""
        return f"{self.message} (Status code: {self.status_code})"


class KVMoveTimeoutError(ServerBusyError):
    """KV-move timeout error, used to trigger the retry mechanism"""

    def __init__(self, message="KV move timeout, please try again later", status_code=503):
        super().__init__(message, status_code)
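
A usage sketch (illustrative, not part of the diff): because KVMoveTimeoutError subclasses ServerBusyError, existing handlers that catch ServerBusyError keep working, while the PD master can single out the timeout case to drive its retry loop. Handler order matters: the subclass must be caught first.

from lightllm.utils.error_utils import ServerBusyError, KVMoveTimeoutError

try:
    raise KVMoveTimeoutError()
except KVMoveTimeoutError as e:
    # retryable: re-select prefill/decode nodes and try again
    print(f"retry: {e}")
except ServerBusyError as e:
    # generic busy error: surface to the client as-is
    print(f"busy: {e}")

assert issubclass(KVMoveTimeoutError, ServerBusyError)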