Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent server hang by adding timeout to network operations #34

Merged
merged 1 commit into from
Mar 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions common/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,23 @@ def _register_again():
threading.Timer(delay_time, _register_again).start()

async def async_unregister(self, s: socket.socket):
await asyncio.get_event_loop().run_in_executor(self.executor, self.unregister, s)
"""添加超时机制的异步取消注册"""
try:
await asyncio.wait_for(
asyncio.get_event_loop().run_in_executor(self.executor, self.unregister, s),
timeout=5 # 5秒超时
)
except asyncio.TimeoutError:
LoggerFactory.get_logger().error(f"Timeout unregistering socket {s}")
# 强制从跟踪字典中移除
if s.fileno() in self.fileno_to_client:
self.fileno_to_client.pop(s.fileno())
if s in self.socket_to_register_lock:
self.socket_to_register_lock.pop(s)
if s in self.socket_to_recv_lock:
self.socket_to_recv_lock.pop(s)
if s in self.waiting_register_socket:
self.waiting_register_socket.remove(s)

def unregister(self, s: socket.socket):
if s not in self.socket_to_register_lock:
Expand All @@ -103,8 +119,10 @@ def unregister(self, s: socket.socket):
pass
except OSError:
LoggerFactory.get_logger().error(traceback.format_exc())
self.socket_to_register_lock.pop(s)
self.socket_to_recv_lock.pop(s)
if s in self.socket_to_register_lock:
self.socket_to_register_lock.pop(s)
if s in self.socket_to_recv_lock:
self.socket_to_recv_lock.pop(s)

def run(self):
while True:
Expand Down
2 changes: 1 addition & 1 deletion constant/system_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ class SystemConstant:

COOKIE_EXPIRE_SECONDS = 3600 * 24

VERSION = '1.1.49'
VERSION = '1.1.52'

GITHUB = 'https://github.com/sazima/proxynt'
45 changes: 39 additions & 6 deletions server/tcp_forward_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,16 +200,49 @@ async def send_to_socket(self, uid: bytes, message: bytes):
if LoggerFactory.get_logger().isEnabledFor(logging.DEBUG):
LoggerFactory.get_logger().debug(f'send to socket uid: {uid}, len: {len(message)}')
try:
await asyncio.get_event_loop().sock_sendall(socket_client, message)
except OSError:
LoggerFactory.get_logger().warn(f'{uid} os error')
pass
# 添加超时机制
await asyncio.wait_for(asyncio.get_event_loop().sock_sendall(socket_client, message), timeout=30)
except asyncio.TimeoutError:
LoggerFactory.get_logger().warn(f"Socket send timeout for {uid}, closing connection")
# 使用 ensure_future 替代 create_task,兼容 Python 3.6
asyncio.ensure_future(self.close_connection_async(connection))
return
except OSError as e:
LoggerFactory.get_logger().warn(f'{uid} os error: {e}')
# 使用 ensure_future 替代 create_task,兼容 Python 3.6
asyncio.ensure_future(self.close_connection_async(connection))
return
if not message:
asyncio.get_event_loop().run_in_executor(None, self.close_connection, connection)

# 使用异步方式关闭连接
asyncio.ensure_future(self.close_connection_async(connection))
if LoggerFactory.get_logger().isEnabledFor(logging.DEBUG):
LoggerFactory.get_logger().debug(f'send to socket cost time {time.time() - send_start_time}')

async def close_connection_async(self, connection: PublicSocketConnection):
"""异步关闭连接,避免在事件循环中阻塞"""
try:
LoggerFactory.get_logger().info(f'async close {connection.uid}')
uid = connection.uid
if uid not in self.uid_to_connection:
return
# 从跟踪字典中移除
self.uid_to_connection.pop(uid, None)
self.socket_to_connection.pop(connection.socket, None)
connection.socket_server.delete_client(connection)
# 确保在关闭前取消注册
try:
await self.socket_event_loop.async_unregister(connection.socket)
except Exception as e:
LoggerFactory.get_logger().error(f'Error unregistering socket: {e}')

# 关闭套接字
try:
connection.socket.close()
except Exception as e:
LoggerFactory.get_logger().error(f'Error closing socket: {e}')
except Exception as e:
LoggerFactory.get_logger().error(f'close error {e}')

def close_connection(self, connection: PublicSocketConnection):
try:
LoggerFactory.get_logger().info(f'close {connection.uid}')
Expand Down