diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 90ad92e33..11daedc98 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -187,7 +187,6 @@ def __init__( self._receive_notification_type = receive_notification_type self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} - self._exit_stack = AsyncExitStack() async def __aenter__(self) -> Self: @@ -232,45 +231,48 @@ async def send_request( ](1) self._response_streams[request_id] = response_stream - self._exit_stack.push_async_callback(lambda: response_stream.aclose()) - self._exit_stack.push_async_callback(lambda: response_stream_reader.aclose()) - - jsonrpc_request = JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - **request.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - - # TODO: Support progress callbacks - - await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) - - # request read timeout takes precedence over session read timeout - timeout = None - if request_read_timeout_seconds is not None: - timeout = request_read_timeout_seconds.total_seconds() - elif self._session_read_timeout_seconds is not None: - timeout = self._session_read_timeout_seconds.total_seconds() - try: - with anyio.fail_after(timeout): - response_or_error = await response_stream_reader.receive() - except TimeoutError: - raise McpError( - ErrorData( - code=httpx.codes.REQUEST_TIMEOUT, - message=( - f"Timed out while waiting for response to " - f"{request.__class__.__name__}. Waited " - f"{timeout} seconds." - ), - ) + jsonrpc_request = JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + **request.model_dump(by_alias=True, mode="json", exclude_none=True), ) - if isinstance(response_or_error, JSONRPCError): - raise McpError(response_or_error.error) - else: - return result_type.model_validate(response_or_error.result) + # TODO: Support progress callbacks + + await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) + + # request read timeout takes precedence over session read timeout + timeout = None + if request_read_timeout_seconds is not None: + timeout = request_read_timeout_seconds.total_seconds() + elif self._session_read_timeout_seconds is not None: + timeout = self._session_read_timeout_seconds.total_seconds() + + try: + with anyio.fail_after(timeout): + response_or_error = await response_stream_reader.receive() + except TimeoutError: + raise McpError( + ErrorData( + code=httpx.codes.REQUEST_TIMEOUT, + message=( + f"Timed out while waiting for response to " + f"{request.__class__.__name__}. Waited " + f"{timeout} seconds." + ), + ) + ) + + if isinstance(response_or_error, JSONRPCError): + raise McpError(response_or_error.error) + else: + return result_type.model_validate(response_or_error.result) + + finally: + self._response_streams.pop(request_id, None) + await response_stream.aclose() + await response_stream_reader.aclose() async def send_notification(self, notification: SendNotificationT) -> None: """ diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py new file mode 100644 index 000000000..990b3a89a --- /dev/null +++ b/tests/client/test_resource_cleanup.py @@ -0,0 +1,68 @@ +from unittest.mock import patch + +import anyio +import pytest + +from mcp.shared.session import BaseSession +from mcp.types import ( + ClientRequest, + EmptyResult, + PingRequest, +) + + +@pytest.mark.anyio +async def test_send_request_stream_cleanup(): + """ + Test that send_request properly cleans up streams when an exception occurs. + + This test mocks out most of the session functionality to focus on stream cleanup. + """ + + # Create a mock session with the minimal required functionality + class TestSession(BaseSession): + async def _send_response(self, request_id, response): + pass + + # Create streams + write_stream_send, write_stream_receive = anyio.create_memory_object_stream(1) + read_stream_send, read_stream_receive = anyio.create_memory_object_stream(1) + + # Create the session + session = TestSession( + read_stream_receive, + write_stream_send, + object, # Request type doesn't matter for this test + object, # Notification type doesn't matter for this test + ) + + # Create a test request + request = ClientRequest( + PingRequest( + method="ping", + ) + ) + + # Patch the _write_stream.send method to raise an exception + async def mock_send(*args, **kwargs): + raise RuntimeError("Simulated network error") + + # Record the response streams before the test + initial_stream_count = len(session._response_streams) + + # Run the test with the patched method + with patch.object(session._write_stream, "send", mock_send): + with pytest.raises(RuntimeError): + await session.send_request(request, EmptyResult) + + # Verify that no response streams were leaked + assert len(session._response_streams) == initial_stream_count, ( + f"Expected {initial_stream_count} response streams after request, " + f"but found {len(session._response_streams)}" + ) + + # Clean up + await write_stream_send.aclose() + await write_stream_receive.aclose() + await read_stream_send.aclose() + await read_stream_receive.aclose()