diff --git a/common/pool.py b/common/pool.py index 93fe6c9..aa3761b 100644 --- a/common/pool.py +++ b/common/pool.py @@ -82,7 +82,23 @@ def _register_again(): threading.Timer(delay_time, _register_again).start() async def async_unregister(self, s: socket.socket): - await asyncio.get_event_loop().run_in_executor(self.executor, self.unregister, s) + """添加超时机制的异步取消注册""" + try: + await asyncio.wait_for( + asyncio.get_event_loop().run_in_executor(self.executor, self.unregister, s), + timeout=5 # 5秒超时 + ) + except asyncio.TimeoutError: + LoggerFactory.get_logger().error(f"Timeout unregistering socket {s}") + # 强制从跟踪字典中移除 + if s.fileno() in self.fileno_to_client: + self.fileno_to_client.pop(s.fileno()) + if s in self.socket_to_register_lock: + self.socket_to_register_lock.pop(s) + if s in self.socket_to_recv_lock: + self.socket_to_recv_lock.pop(s) + if s in self.waiting_register_socket: + self.waiting_register_socket.remove(s) def unregister(self, s: socket.socket): if s not in self.socket_to_register_lock: @@ -103,8 +119,10 @@ def unregister(self, s: socket.socket): pass except OSError: LoggerFactory.get_logger().error(traceback.format_exc()) - self.socket_to_register_lock.pop(s) - self.socket_to_recv_lock.pop(s) + if s in self.socket_to_register_lock: + self.socket_to_register_lock.pop(s) + if s in self.socket_to_recv_lock: + self.socket_to_recv_lock.pop(s) def run(self): while True: diff --git a/constant/system_constant.py b/constant/system_constant.py index 225c305..a5161a7 100644 --- a/constant/system_constant.py +++ b/constant/system_constant.py @@ -15,6 +15,6 @@ class SystemConstant: COOKIE_EXPIRE_SECONDS = 3600 * 24 - VERSION = '1.1.49' + VERSION = '1.1.52' GITHUB = 'https://github.com/sazima/proxynt' diff --git a/server/tcp_forward_client.py b/server/tcp_forward_client.py index 6889fc2..cd88f5b 100644 --- a/server/tcp_forward_client.py +++ b/server/tcp_forward_client.py @@ -200,16 +200,49 @@ async def send_to_socket(self, uid: bytes, message: bytes): if LoggerFactory.get_logger().isEnabledFor(logging.DEBUG): LoggerFactory.get_logger().debug(f'send to socket uid: {uid}, len: {len(message)}') try: - await asyncio.get_event_loop().sock_sendall(socket_client, message) - except OSError: - LoggerFactory.get_logger().warn(f'{uid} os error') - pass + # 添加超时机制 + await asyncio.wait_for(asyncio.get_event_loop().sock_sendall(socket_client, message), timeout=30) + except asyncio.TimeoutError: + LoggerFactory.get_logger().warn(f"Socket send timeout for {uid}, closing connection") + # 使用 ensure_future 替代 create_task,兼容 Python 3.6 + asyncio.ensure_future(self.close_connection_async(connection)) + return + except OSError as e: + LoggerFactory.get_logger().warn(f'{uid} os error: {e}') + # 使用 ensure_future 替代 create_task,兼容 Python 3.6 + asyncio.ensure_future(self.close_connection_async(connection)) + return if not message: - asyncio.get_event_loop().run_in_executor(None, self.close_connection, connection) - + # 使用异步方式关闭连接 + asyncio.ensure_future(self.close_connection_async(connection)) if LoggerFactory.get_logger().isEnabledFor(logging.DEBUG): LoggerFactory.get_logger().debug(f'send to socket cost time {time.time() - send_start_time}') + async def close_connection_async(self, connection: PublicSocketConnection): + """异步关闭连接,避免在事件循环中阻塞""" + try: + LoggerFactory.get_logger().info(f'async close {connection.uid}') + uid = connection.uid + if uid not in self.uid_to_connection: + return + # 从跟踪字典中移除 + self.uid_to_connection.pop(uid, None) + self.socket_to_connection.pop(connection.socket, None) + connection.socket_server.delete_client(connection) + # 确保在关闭前取消注册 + try: + await self.socket_event_loop.async_unregister(connection.socket) + except Exception as e: + LoggerFactory.get_logger().error(f'Error unregistering socket: {e}') + + # 关闭套接字 + try: + connection.socket.close() + except Exception as e: + LoggerFactory.get_logger().error(f'Error closing socket: {e}') + except Exception as e: + LoggerFactory.get_logger().error(f'close error {e}') + def close_connection(self, connection: PublicSocketConnection): try: LoggerFactory.get_logger().info(f'close {connection.uid}')