diff --git a/src/postgres_mcp/__init__.py b/src/postgres_mcp/__init__.py index a00e3497..fccd4e11 100644 --- a/src/postgres_mcp/__init__.py +++ b/src/postgres_mcp/__init__.py @@ -13,7 +13,11 @@ def main(): if sys.platform == "win32": asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - asyncio.run(server.main()) + try: + asyncio.run(server.main()) + except KeyboardInterrupt: + # Handle Ctrl+C gracefully without printing a traceback + pass # Optionally expose other important items at package level diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index 6a825a1f..2b6e7cd8 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -1,10 +1,7 @@ # ruff: noqa: B008 import argparse -import asyncio import logging import os -import signal -import sys from enum import Enum from typing import Any from typing import List @@ -56,7 +53,6 @@ class AccessMode(str, Enum): # Global variables db_connection = DbConnPool() current_access_mode = AccessMode.UNRESTRICTED -shutdown_in_progress = False async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]: @@ -633,47 +629,24 @@ async def main(): "The MCP server will start but database operations will fail until a valid connection is established.", ) - # Set up proper shutdown handling + # Run the server with the selected transport, with proper cleanup on exit try: - loop = asyncio.get_running_loop() - signals = (signal.SIGTERM, signal.SIGINT) - for s in signals: - loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown(s))) - except NotImplementedError: - # Windows doesn't support signals properly - logger.warning("Signal handling not supported on Windows") - pass - - # Run the server with the selected transport (always async) - if args.transport == "stdio": - await mcp.run_stdio_async() - else: - # Update FastMCP settings based on command line arguments - mcp.settings.host = args.sse_host - mcp.settings.port = args.sse_port - await mcp.run_sse_async() - - -async def shutdown(sig=None): - """Clean shutdown of the server.""" - global shutdown_in_progress - - if shutdown_in_progress: - logger.warning("Forcing immediate exit") - # Use sys.exit instead of os._exit to allow for proper cleanup - sys.exit(1) - - shutdown_in_progress = True + if args.transport == "stdio": + await mcp.run_stdio_async() + else: + mcp.settings.host = args.sse_host + mcp.settings.port = args.sse_port + await mcp.run_sse_async() + finally: + # Clean up database connections on exit + await cleanup() - if sig: - logger.info(f"Received exit signal {sig.name}") - # Close database connections +async def cleanup(): + """Clean up resources on server shutdown.""" + logger.info("Shutting down server...") try: await db_connection.close() logger.info("Closed database connections") except Exception as e: logger.error(f"Error closing database connections: {e}") - - # Exit with appropriate status code - sys.exit(128 + sig if sig is not None else 0) diff --git a/tests/unit/test_access_mode.py b/tests/unit/test_access_mode.py index f7d3b803..c10d01e5 100644 --- a/tests/unit/test_access_mode.py +++ b/tests/unit/test_access_mode.py @@ -92,7 +92,7 @@ async def test_command_line_parsing(): patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock()), - patch("postgres_mcp.server.shutdown", AsyncMock()), + patch("postgres_mcp.server.cleanup", AsyncMock()), ): # Reset the current_access_mode to UNRESTRICTED import postgres_mcp.server diff --git a/tests/unit/test_shutdown.py b/tests/unit/test_shutdown.py new file mode 100644 index 00000000..3ac94359 --- /dev/null +++ b/tests/unit/test_shutdown.py @@ -0,0 +1,85 @@ +import sys +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + + +@pytest.mark.asyncio +async def test_cleanup_closes_db_connection(): + """Test that cleanup properly closes database connections.""" + from postgres_mcp.server import cleanup + + mock_db = MagicMock() + mock_db.close = AsyncMock() + + with patch("postgres_mcp.server.db_connection", mock_db): + await cleanup() + mock_db.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_cleanup_handles_db_close_error(): + """Test that cleanup handles errors when closing database connections.""" + from postgres_mcp.server import cleanup + + mock_db = MagicMock() + mock_db.close = AsyncMock(side_effect=Exception("Connection error")) + + with patch("postgres_mcp.server.db_connection", mock_db): + # Should not raise, just log the error + await cleanup() + mock_db.close.assert_called_once() + + +@pytest.mark.asyncio +async def test_main_calls_cleanup_on_normal_exit(): + """Test that main() calls cleanup when transport exits normally.""" + from postgres_mcp.server import main + + original_argv = sys.argv + try: + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + ] + + mock_cleanup = AsyncMock() + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock()), + patch("postgres_mcp.server.cleanup", mock_cleanup), + ): + await main() + mock_cleanup.assert_called_once() + finally: + sys.argv = original_argv + + +@pytest.mark.asyncio +async def test_main_calls_cleanup_on_exception(): + """Test that main() calls cleanup even when transport raises an exception.""" + from postgres_mcp.server import main + + original_argv = sys.argv + try: + sys.argv = [ + "postgres_mcp", + "postgresql://user:password@localhost/db", + ] + + mock_cleanup = AsyncMock() + + with ( + patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock(side_effect=Exception("Transport error"))), + patch("postgres_mcp.server.cleanup", mock_cleanup), + ): + with pytest.raises(Exception, match="Transport error"): + await main() + # Cleanup should still be called due to finally block + mock_cleanup.assert_called_once() + finally: + sys.argv = original_argv