diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index a75cfd764..8c1eb2414 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -16,7 +16,6 @@ from .win32 import ( create_windows_process, get_windows_executable_command, - terminate_windows_process, ) # Environment variables to inherit by default @@ -179,11 +178,10 @@ async def stdin_writer(): yield read_stream, write_stream finally: # Clean up process to prevent any dangling orphaned processes + # Unified cleanup across all platforms: simple terminate + stream cleanup + # Stream cleanup (PR #559) is the key to preventing hanging behavior try: - if sys.platform == "win32": - await terminate_windows_process(process) - else: - process.terminate() + process.terminate() except ProcessLookupError: # Process already exited, which is fine pass diff --git a/src/mcp/client/stdio/win32.py b/src/mcp/client/stdio/win32.py index 7246b9dec..55c3a94d0 100644 --- a/src/mcp/client/stdio/win32.py +++ b/src/mcp/client/stdio/win32.py @@ -161,22 +161,5 @@ async def create_windows_process( return FallbackProcess(popen_obj) -async def terminate_windows_process(process: Process | FallbackProcess): - """ - Terminate a Windows process. - - Note: On Windows, terminating a process with process.terminate() doesn't - always guarantee immediate process termination. - So we give it 2s to exit, or we call process.kill() - which sends a SIGKILL equivalent signal. - - Args: - process: The process to terminate - """ - try: - process.terminate() - with anyio.fail_after(2.0): - await process.wait() - except TimeoutError: - # Force kill if it doesn't terminate - process.kill() +# Windows-specific process termination function removed +# Unified cleanup now uses simple process.terminate() + stream cleanup across all platforms diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index c66a16ab9..58b9f04ac 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,4 +1,6 @@ import shutil +import sys +import time import pytest @@ -90,3 +92,93 @@ async def test_stdio_client_nonexistent_command(): or "not found" in error_message.lower() or "cannot find the file" in error_message.lower() # Windows error message ) + + +@pytest.mark.anyio +async def test_stdio_client_universal_timeout(): + """ + Test that stdio_client completes cleanup within reasonable time + even when connected to processes that exit slowly. + """ + + # Use a simple sleep command that's available on all platforms + # This simulates a process that takes time to terminate + if sys.platform == "win32": + # Windows: use ping with timeout to simulate a running process + server_params = StdioServerParameters( + command="ping", + args=["127.0.0.1", "-n", "10"], # Ping 10 times, takes ~10 seconds + ) + else: + # Unix: use sleep command + server_params = StdioServerParameters( + command="sleep", + args=["10"], # Sleep for 10 seconds + ) + + start_time = time.time() + + try: + async with stdio_client(server_params) as (read_stream, write_stream): + # Immediately exit - this triggers cleanup while process is still running + pass + + end_time = time.time() + elapsed = end_time - start_time + + # Key assertion: Should complete quickly due to timeout mechanism + # Before PR #555, Unix systems might hang for the full 10 seconds + # After PR #555, all platforms should complete within ~2-3 seconds + assert elapsed < 5.0, ( + f"stdio_client cleanup took {elapsed:.1f} seconds, expected < 5.0 seconds. " + f"This suggests the timeout mechanism may not be working properly." + ) + + except Exception as e: + end_time = time.time() + elapsed = end_time - start_time + print(f"❌ Test failed after {elapsed:.1f} seconds: {e}") + raise + + +@pytest.mark.anyio +async def test_stdio_client_immediate_completion(): + """ + Test that stdio_client doesn't introduce unnecessary delays + when processes exit normally and quickly. + + This ensures PR #555's timeout mechanism doesn't slow down normal operation. + """ + + # Use a command that exits immediately + if sys.platform == "win32": + server_params = StdioServerParameters( + command="cmd", + args=["/c", "echo", "hello"], # Windows: echo and exit + ) + else: + server_params = StdioServerParameters( + command="echo", + args=["hello"], # Unix: echo and exit + ) + + start_time = time.time() + + try: + async with stdio_client(server_params) as (read_stream, write_stream): + pass + + end_time = time.time() + elapsed = end_time - start_time + + # Should complete very quickly when process exits normally + assert elapsed < 2.0, ( + f"stdio_client took {elapsed:.1f} seconds for fast-exiting process, " + f"expected < 2.0 seconds. Timeout mechanism may be introducing delays." + ) + + except Exception as e: + end_time = time.time() + elapsed = end_time - start_time + print(f"❌ Test failed after {elapsed:.1f} seconds: {e}") + raise