Skip to content

Commit

Permalink
handle client forward error
Browse files Browse the repository at this point in the history
  • Loading branch information
wnma3mz committed Feb 2, 2025
1 parent b4fa0fb commit 9d127c6
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tllm/commons/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def build(self, seq_input: SeqInput, cache_manager: "CacheManager"):
# 命中了之前的 kv cache,使用历史 cache
if hit_uuid is not None and cache_manager.get(hit_uuid) is not None:
hid_decoder_cache: DecoderCache = copy.deepcopy(cache_manager.get(hit_uuid))
# 相同请求时,避免过超过 cache 长度
# 相同输入时,避免超过 cache 长度
if q_len <= hit_cache_len:
hit_cache_len = q_len - 1
hit_cache_len = q_len - 2

hid_decoder_cache.truncate(hit_cache_len)
hid_decoder_cache.set_q_len(q_len - hit_cache_len)
Expand Down
10 changes: 9 additions & 1 deletion tllm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,16 @@ async def _generate(self):

# await asyncio.sleep(self.sleep_time)
except asyncio.CancelledError:
self.logger.debug("CancelledError")
# LLM Generate Error or Server Cancel Engine
self.logger.error("CancelledError")
for request_data in request_data_list:
request_data.is_stop = True
request_data.finish_reason_list = ["Server Error"]
async with request_data.condition:
request_data.condition.notify()
traceback.print_exc()
except Exception as e:
self.logger.error(f"Error input_ids: {'\n'.join(x.input_ids for x in request_data_list)}")
self.logger.error(f"Error processing prefill_queue data: {str(e)}")
traceback.print_exc()
except BaseException as e:
Expand Down
8 changes: 8 additions & 0 deletions tllm/grpc/master_service/worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from tllm.grpc.master_service.pending_requests import PendingRequests
from tllm.grpc.proto import schemas_pb2, schemas_pb2_grpc
from tllm.schemas import MIX_TENSOR, SeqInput
from tllm.singleton_logger import SingletonLogger

logger = SingletonLogger.setup_master_logger()


async def rpc_image_forward(
Expand Down Expand Up @@ -107,6 +110,11 @@ async def forward(self, hidden_states: MIX_TENSOR, seq_input: SeqInput) -> Tuple
try:
output = await asyncio.wait_for(forward_future, timeout=PP_TIMEOUT)
except asyncio.CancelledError:
logger.error("Client Timeout Error")
raise asyncio.CancelledError

if output.data == b"":
logger.error("Client Forward Error")
raise asyncio.CancelledError

return convertor.deserialize(output), await asyncio.wait_for(status_future, timeout=PP_TIMEOUT)
Expand Down
13 changes: 12 additions & 1 deletion tllm/grpc/worker_service/worker_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ async def stop(self):
pass

async def forward_func(self, request: schemas_pb2.ForwardRequest):
# 如果收到为空,则直接返回
if request.hidden_states.data == b"":
if self.comm.is_rank0():
await self.master_rpc_manager.rpc_func(request, None, -1)
return
s1 = time.perf_counter()

convertor = Convertor()
Expand All @@ -85,7 +90,13 @@ async def forward_func(self, request: schemas_pb2.ForwardRequest):
self.comm.debug_rank0(f"deserialize_tensor cost time: {time.perf_counter() - s1:.4f}")

s1 = time.perf_counter()
output_hidden_states = self.model(hidden_states, seq_input)
try:
output_hidden_states = self.model(hidden_states, seq_input)
except ValueError as e:
self.logger.error(f"forward_func ValueError by {str(e)}")
if self.comm.is_rank0():
await self.master_rpc_manager.rpc_func(request, None, -1)
return
cost_time = time.perf_counter() - s1
self.comm.debug_rank0(f"forward cost time: {cost_time:.4f}")

Expand Down

0 comments on commit 9d127c6

Please sign in to comment.