diff --git a/tllm/commons/cache.py b/tllm/commons/cache.py
index d2f5af7..c2deebc 100644
--- a/tllm/commons/cache.py
+++ b/tllm/commons/cache.py
@@ -116,9 +116,9 @@ def build(self, seq_input: SeqInput, cache_manager: "CacheManager"):
         # Hit a previous kv cache; reuse the historical cache
         if hit_uuid is not None and cache_manager.get(hit_uuid) is not None:
             hid_decoder_cache: DecoderCache = copy.deepcopy(cache_manager.get(hit_uuid))
-            # For identical requests, avoid exceeding the cache length
+            # For identical inputs, avoid exceeding the cache length
             if q_len <= hit_cache_len:
-                hit_cache_len = q_len - 1
+                hit_cache_len = q_len - 2
             hid_decoder_cache.truncate(hit_cache_len)
             hid_decoder_cache.set_q_len(q_len - hit_cache_len)
 
diff --git a/tllm/engine.py b/tllm/engine.py
index 4d9ce89..f24196e 100644
--- a/tllm/engine.py
+++ b/tllm/engine.py
@@ -84,8 +84,16 @@ async def _generate(self):
 
                 # await asyncio.sleep(self.sleep_time)
             except asyncio.CancelledError:
-                self.logger.debug("CancelledError")
+                # LLM generation failed or the server cancelled the engine
+                self.logger.error("CancelledError")
+                for request_data in request_data_list:
+                    request_data.is_stop = True
+                    request_data.finish_reason_list = ["Server Error"]
+                    async with request_data.condition:
+                        request_data.condition.notify()
+                traceback.print_exc()
             except Exception as e:
+                self.logger.error("Error input_ids: " + "\n".join(str(x.input_ids) for x in request_data_list))
                 self.logger.error(f"Error processing prefill_queue data: {str(e)}")
                 traceback.print_exc()
             except BaseException as e:
diff --git a/tllm/grpc/master_service/worker_manager.py b/tllm/grpc/master_service/worker_manager.py
index 3e8f807..0a0cdc4 100644
--- a/tllm/grpc/master_service/worker_manager.py
+++ b/tllm/grpc/master_service/worker_manager.py
@@ -10,6 +10,9 @@
 from tllm.grpc.master_service.pending_requests import PendingRequests
 from tllm.grpc.proto import schemas_pb2, schemas_pb2_grpc
 from tllm.schemas import MIX_TENSOR, SeqInput
+from tllm.singleton_logger import SingletonLogger
+
+logger = SingletonLogger.setup_master_logger()
 
 
 async def rpc_image_forward(
@@ -107,6 +110,11 @@ async def forward(self, hidden_states: MIX_TENSOR, seq_input: SeqInput) -> Tuple
         try:
             output = await asyncio.wait_for(forward_future, timeout=PP_TIMEOUT)
         except asyncio.CancelledError:
+            logger.error("Client Timeout Error")
+            raise asyncio.CancelledError
+
+        if output.data == b"":
+            logger.error("Client Forward Error")
             raise asyncio.CancelledError
 
         return convertor.deserialize(output), await asyncio.wait_for(status_future, timeout=PP_TIMEOUT)
diff --git a/tllm/grpc/worker_service/worker_server.py b/tllm/grpc/worker_service/worker_server.py
index cc05422..5f24a44 100644
--- a/tllm/grpc/worker_service/worker_server.py
+++ b/tllm/grpc/worker_service/worker_server.py
@@ -75,6 +75,11 @@ async def stop(self):
         pass
 
     async def forward_func(self, request: schemas_pb2.ForwardRequest):
+        # If the received hidden states are empty, return immediately
+        if request.hidden_states.data == b"":
+            if self.comm.is_rank0():
+                await self.master_rpc_manager.rpc_func(request, None, -1)
+            return
         s1 = time.perf_counter()
 
         convertor = Convertor()
@@ -85,7 +90,13 @@ async def forward_func(self, request: schemas_pb2.ForwardRequest):
         self.comm.debug_rank0(f"deserialize_tensor cost time: {time.perf_counter() - s1:.4f}")
 
         s1 = time.perf_counter()
-        output_hidden_states = self.model(hidden_states, seq_input)
+        try:
+            output_hidden_states = self.model(hidden_states, seq_input)
+        except ValueError as e:
+            self.logger.error(f"forward_func ValueError: {str(e)}")
+            if self.comm.is_rank0():
+                await self.master_rpc_manager.rpc_func(request, None, -1)
+            return
         cost_time = time.perf_counter() - s1
         self.comm.debug_rank0(f"forward cost time: {cost_time:.4f}")