Skip to content

Commit dfa7c6a

Browse files
committed
Fix: SseServerTransport.handle_sse never ends and can lead to memory leaks
1 parent b4c7db6 commit dfa7c6a

File tree

2 files changed

+34
-9
lines changed

2 files changed

+34
-9
lines changed

examples/servers/simple-tool/mcp_simple_tool/server.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,24 @@ async def list_tools() -> list[types.Tool]:
6060
if transport == "sse":
6161
from mcp.server.sse import SseServerTransport
6262
from starlette.applications import Starlette
63+
from starlette.requests import Request
64+
from starlette.responses import Response
6365
from starlette.routing import Mount, Route
6466

6567
sse = SseServerTransport("/messages/")
6668

67-
async def handle_sse(request):
68-
async with sse.connect_sse(
69-
request.scope, request.receive, request._send
70-
) as streams:
71-
await app.run(
72-
streams[0], streams[1], app.create_initialization_options()
73-
)
69+
async def handle_sse(request: Request):
70+
with anyio.CancelScope() as cancel_scope:
71+
async with sse.connect_sse(
72+
request.scope,
73+
request.receive,
74+
request._send,
75+
lambda: cancel_scope.cancel(),
76+
) as streams:
77+
await app.run(
78+
streams[0], streams[1], app.create_initialization_options()
79+
)
80+
return Response(status_code=200)
7481

7582
starlette_app = Starlette(
7683
debug=True,

src/mcp/server/sse.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ async def handle_sse(request):
3232
"""
3333

3434
import logging
35+
from collections.abc import Callable
3536
from contextlib import asynccontextmanager
3637
from typing import Any
3738
from urllib.parse import quote
@@ -41,6 +42,7 @@ async def handle_sse(request):
4142
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
4243
from pydantic import ValidationError
4344
from sse_starlette import EventSourceResponse
45+
from starlette.background import BackgroundTask
4446
from starlette.requests import Request
4547
from starlette.responses import Response
4648
from starlette.types import Receive, Scope, Send
@@ -79,7 +81,13 @@ def __init__(self, endpoint: str) -> None:
7981
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
8082

8183
@asynccontextmanager
82-
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
84+
async def connect_sse(
85+
self,
86+
scope: Scope,
87+
receive: Receive,
88+
send: Send,
89+
callback: Callable[[], None] | None = None,
90+
):
8391
if scope["type"] != "http":
8492
logger.error("connect_sse received non-HTTP request")
8593
raise ValueError("connect_sse can only handle HTTP requests")
@@ -120,9 +128,19 @@ async def sse_writer():
120128
}
121129
)
122130

131+
async def _remove_stream_writer() -> None:
132+
await read_stream_writer.aclose()
133+
await write_stream_reader.aclose()
134+
del self._read_stream_writers[session_id]
135+
if callback:
136+
callback()
137+
logger.debug(f"Closed SSE session with ID: {session_id}")
138+
123139
async with anyio.create_task_group() as tg:
124140
response = EventSourceResponse(
125-
content=sse_stream_reader, data_sender_callable=sse_writer
141+
content=sse_stream_reader,
142+
data_sender_callable=sse_writer,
143+
background=BackgroundTask(_remove_stream_writer),
126144
)
127145
logger.debug("Starting SSE response task")
128146
tg.start_soon(response, scope, receive, send)

0 commit comments

Comments
 (0)