From f7d33ad00eb8fb1311d1480762451d50bb917bc9 Mon Sep 17 00:00:00 2001 From: Hoa Lam Date: Thu, 24 Apr 2025 16:43:23 +0700 Subject: [PATCH 1/3] Fix handle sse disconnect --- src/mcp/server/sse.py | 10 ++++++++-- uv.lock | 27 +++++++++++++++------------ 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 0127753d..d503b307 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -43,7 +43,7 @@ async def handle_sse(request): from sse_starlette import EventSourceResponse from starlette.requests import Request from starlette.responses import Response -from starlette.types import Receive, Scope, Send +from starlette.types import Message, Receive, Scope, Send import mcp.types as types @@ -120,9 +120,15 @@ async def sse_writer(): } ) + async def handle_see_disconnect(message: Message) -> None: + logger.debug(f"Disconnect sse {session_id}") + del self._read_stream_writers[session_id] + async with anyio.create_task_group() as tg: response = EventSourceResponse( - content=sse_stream_reader, data_sender_callable=sse_writer + content=sse_stream_reader, + data_sender_callable=sse_writer, + client_close_handler_callable=handle_see_disconnect, # type: ignore ) logger.debug("Starting SSE response task") tg.start_soon(response, scope, receive, send) diff --git a/uv.lock b/uv.lock index 7ff1a3ea..48911fbe 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.10" [options] @@ -23,17 +24,17 @@ wheels = [ [[package]] name = "anyio" -version = "4.5.0" +version = "4.7.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "idna" }, { name = "sniffio" }, - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a0/44/66874c5256e9fbc30103b31927fd9341c8da6ccafd4721b2b3e81e6ef176/anyio-4.5.0.tar.gz", hash = "sha256:c5a275fe5ca0afd788001f58fca1e69e29ce706d746e317d660e21f70c530ef9", size = 169376 } +sdist = { url = "https://files.pythonhosted.org/packages/f6/40/318e58f669b1a9e00f5c4453910682e2d9dd594334539c7b7817dabb765f/anyio-4.7.0.tar.gz", hash = "sha256:2f834749c602966b7d456a7567cafcb309f96482b5081d14ac93ccd457f9dd48", size = 177076 } wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/68/f9e9bf6324c46e6b8396610aef90ad423ec3e18c9079547ceafea3dce0ec/anyio-4.5.0-py3-none-any.whl", hash = "sha256:fdeb095b7cc5a5563175eedd926ec4ae55413bb4be5770c424af0ba46ccb4a78", size = 89250 }, + { url = "https://files.pythonhosted.org/packages/a0/7a/4daaf3b6c08ad7ceffea4634ec206faeff697526421c20f07628c7372156/anyio-4.7.0-py3-none-any.whl", hash = "sha256:ea60c3723ab42ba6fff7e8ccb0488c898ec538ff4df1f1d5e642c3601d07e352", size = 93052 }, ] [[package]] @@ -78,7 +79,7 @@ name = "click" version = "8.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/45/2b/7ebad1e59a99207d417c0784f7fb67893465eef84b5b47c788324f1b4095/click-8.1.0.tar.gz", hash = "sha256:977c213473c7665d3aa092b41ff12063227751c41d7b17165013e10069cc5cd2", size = 329986 } wheels = [ @@ -191,7 +192,7 @@ wheels = [ [[package]] name = "mcp" -version = "1.3.0.dev0" +version = "1.3.0" source = { editable = "." } dependencies = [ { name = "anyio" }, @@ -237,6 +238,7 @@ requires-dist = [ { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12.4" }, { name = "uvicorn", specifier = ">=0.23.1" }, ] +provides-extras = ["rich", "cli"] [package.metadata.requires-dev] dev = [ @@ -647,26 +649,27 @@ wheels = [ [[package]] name = "sse-starlette" -version = "1.6.1" +version = "2.3.3" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "anyio" }, { name = "starlette" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/40/88/0af7f586894cfe61bd212f33e571785c4570085711b24fb7445425a5eeb0/sse-starlette-1.6.1.tar.gz", hash = "sha256:6208af2bd7d0887c92f1379da14bd1f4db56bd1274cc5d36670c683d2aa1de6a", size = 14555 } +sdist = { url = "https://files.pythonhosted.org/packages/86/35/7d8d94eb0474352d55f60f80ebc30f7e59441a29e18886a6425f0bccd0d3/sse_starlette-2.3.3.tar.gz", hash = "sha256:fdd47c254aad42907cfd5c5b83e2282be15be6c51197bf1a9b70b8e990522072", size = 17499 } wheels = [ - { url = "https://files.pythonhosted.org/packages/5e/f7/499e5d0c181a52a205d5b0982fd71cf162d1e070c97dca90c60520bbf8bf/sse_starlette-1.6.1-py3-none-any.whl", hash = "sha256:d8f18f1c633e355afe61cc5e9c92eea85badcb8b2d56ec8cfb0a006994aa55da", size = 9553 }, + { url = "https://files.pythonhosted.org/packages/5d/20/52fdb5ebb158294b0adb5662235dd396fc7e47aa31c293978d8d8942095a/sse_starlette-2.3.3-py3-none-any.whl", hash = "sha256:8b0a0ced04a329ff7341b01007580dd8cf71331cc21c0ccea677d500618da1e0", size = 10235 }, ] [[package]] name = "starlette" -version = "0.27.0" +version = "0.41.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/06/68/559bed5484e746f1ab2ebbe22312f2c25ec62e4b534916d41a8c21147bf8/starlette-0.27.0.tar.gz", hash = "sha256:6a6b0d042acb8d469a01eba54e9cda6cbd24ac602c4cd016723117d6a7e73b75", size = 51394 } +sdist = { url = "https://files.pythonhosted.org/packages/1a/4c/9b5764bd22eec91c4039ef4c55334e9187085da2d8a2df7bd570869aae18/starlette-0.41.3.tar.gz", hash = "sha256:0e4ab3d16522a255be6b28260b938eae2482f98ce5cc934cb08dce8dc3ba5835", size = 2574159 } wheels = [ - { url = "https://files.pythonhosted.org/packages/58/f8/e2cca22387965584a409795913b774235752be4176d276714e15e1a58884/starlette-0.27.0-py3-none-any.whl", hash = "sha256:918416370e846586541235ccd38a474c08b80443ed31c578a418e2209b3eef91", size = 66978 }, + { url = "https://files.pythonhosted.org/packages/96/00/2b325970b3060c7cecebab6d295afe763365822b1306a12eeab198f74323/starlette-0.41.3-py3-none-any.whl", hash = "sha256:44cedb2b7c77a9de33a8b74b2b90e9f50d11fcf25d8270ea525ad71a25374ff7", size = 73225 }, ] [[package]] From e8e0491e24c0507537abbeae4d4893cd12f4ba51 Mon Sep 17 00:00:00 2001 From: Hoa Lam Date: Fri, 25 Apr 2025 09:01:25 +0700 Subject: [PATCH 2/3] Add close read write stream --- src/mcp/server/sse.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index d503b307..08e15efc 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -123,12 +123,16 @@ async def sse_writer(): async def handle_see_disconnect(message: Message) -> None: logger.debug(f"Disconnect sse {session_id}") del self._read_stream_writers[session_id] + await read_stream.aclose() + await read_stream_writer.aclose() + await write_stream.aclose() + await write_stream_reader.aclose() async with anyio.create_task_group() as tg: response = EventSourceResponse( content=sse_stream_reader, data_sender_callable=sse_writer, - client_close_handler_callable=handle_see_disconnect, # type: ignore + client_close_handler_callable=handle_see_disconnect # type: ignore ) logger.debug("Starting SSE response task") tg.start_soon(response, scope, receive, send) From c7ef8a201b424113799892e2d9902952bb3efc22 Mon Sep 17 00:00:00 2001 From: Hoa Lam Date: Sat, 26 Apr 2025 19:58:42 +0700 Subject: [PATCH 3/3] Add unittest sse disconnect --- tests/server/test_sse_disconnect.py | 54 +++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 tests/server/test_sse_disconnect.py diff --git a/tests/server/test_sse_disconnect.py b/tests/server/test_sse_disconnect.py new file mode 100644 index 00000000..3dc6e005 --- /dev/null +++ b/tests/server/test_sse_disconnect.py @@ -0,0 +1,54 @@ +import asyncio +from uuid import UUID + +import pytest +from starlette.types import Message, Scope + +from mcp.server.sse import SseServerTransport + + +@pytest.mark.anyio +async def test_sse_disconnect_handle(): + transport = SseServerTransport(endpoint="/sse") + # Create a minimal ASGI scope for an HTTP GET request + scope: Scope = { + "type": "http", + "method": "GET", + "path": "/sse", + "headers": [], + } + send_disconnect = False + + # Dummy receive and send functions + async def receive() -> dict: + nonlocal send_disconnect + if not send_disconnect: + send_disconnect = True + return {"type": "http.request"} + else: + return {"type": "http.disconnect"} + + async def send(message: Message) -> None: + await asyncio.sleep(0) + + # Run the connect_sse context manager + async with transport.connect_sse(scope, receive, send) as ( + read_stream, + write_stream, + ): + # Assert that streams are provided + assert read_stream is not None + assert write_stream is not None + + # There should be exactly one session + assert len(transport._read_stream_writers) == 1 + # Check that the session key is a UUID + session_id = next(iter(transport._read_stream_writers.keys())) + assert isinstance(session_id, UUID) + + # Check that the writer is still open + writer = transport._read_stream_writers[session_id] + assert writer is not None + + # After context exits, session should be cleaned up + assert len(transport._read_stream_writers) == 0