Skip to content

Commit a5cd815

Browse files
fix: serialize streamable-http MCP requests per session (#2682)
1 parent 710449c commit a5cd815

File tree

2 files changed

+96
-8
lines changed

2 files changed

+96
-8
lines changed

src/agents/mcp/server.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ def __init__(
336336
self.session: ClientSession | None = None
337337
self.exit_stack: AsyncExitStack = AsyncExitStack()
338338
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
339+
self._request_lock: asyncio.Lock = asyncio.Lock()
339340
self.cache_tools_list = cache_tools_list
340341
self.server_initialize_result: InitializeResult | None = None
341342

@@ -349,6 +350,13 @@ def __init__(
349350
self._tools_list: list[MCPTool] | None = None
350351

351352
self.tool_filter = tool_filter
353+
self._serialize_session_requests = False
354+
355+
async def _maybe_serialize_request(self, func: Callable[[], Awaitable[T]]) -> T:
356+
if not self._serialize_session_requests:
357+
return await func()
358+
async with self._request_lock:
359+
return await func()
352360

353361
async def _apply_tool_filter(
354362
self,
@@ -573,7 +581,9 @@ async def list_tools(
573581
tools = self._tools_list
574582
else:
575583
# Fetch the tools from the server
576-
result = await self._run_with_retries(lambda: session.list_tools())
584+
result = await self._run_with_retries(
585+
lambda: self._maybe_serialize_request(lambda: session.list_tools())
586+
)
577587
self._tools_list = result.tools
578588
self._cache_dirty = False
579589
tools = self._tools_list
@@ -609,9 +619,15 @@ async def call_tool(
609619
try:
610620
self._validate_required_parameters(tool_name=tool_name, arguments=arguments)
611621
if meta is None:
612-
return await self._run_with_retries(lambda: session.call_tool(tool_name, arguments))
622+
return await self._run_with_retries(
623+
lambda: self._maybe_serialize_request(
624+
lambda: session.call_tool(tool_name, arguments)
625+
)
626+
)
613627
return await self._run_with_retries(
614-
lambda: session.call_tool(tool_name, arguments, meta=meta)
628+
lambda: self._maybe_serialize_request(
629+
lambda: session.call_tool(tool_name, arguments, meta=meta)
630+
)
615631
)
616632
except httpx.HTTPStatusError as e:
617633
status_code = e.response.status_code
@@ -665,17 +681,19 @@ async def list_prompts(
665681
"""List the prompts available on the server."""
666682
if not self.session:
667683
raise UserError("Server not initialized. Make sure you call `connect()` first.")
668-
669-
return await self.session.list_prompts()
684+
session = self.session
685+
assert session is not None
686+
return await self._maybe_serialize_request(lambda: session.list_prompts())
670687

671688
async def get_prompt(
672689
self, name: str, arguments: dict[str, Any] | None = None
673690
) -> GetPromptResult:
674691
"""Get a specific prompt from the server."""
675692
if not self.session:
676693
raise UserError("Server not initialized. Make sure you call `connect()` first.")
677-
678-
return await self.session.get_prompt(name, arguments)
694+
session = self.session
695+
assert session is not None
696+
return await self._maybe_serialize_request(lambda: session.get_prompt(name, arguments))
679697

680698
async def cleanup(self):
681699
"""Cleanup the server."""
@@ -1084,6 +1102,7 @@ def __init__(
10841102

10851103
self.params = params
10861104
self._name = name or f"streamable_http: {self.params['url']}"
1105+
self._serialize_session_requests = True
10871106

10881107
def create_streams(
10891108
self,

tests/mcp/test_client_session_retries.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from typing import cast
23

34
import pytest
@@ -29,14 +30,15 @@ async def list_tools(self):
2930

3031

3132
class DummyServer(_MCPServerWithClientSession):
32-
def __init__(self, session: DummySession, retries: int):
33+
def __init__(self, session: DummySession, retries: int, *, serialize_requests: bool = False):
3334
super().__init__(
3435
cache_tools_list=False,
3536
client_session_timeout_seconds=None,
3637
max_retry_attempts=retries,
3738
retry_backoff_seconds_base=0,
3839
)
3940
self.session = cast(ClientSession, session)
41+
self._serialize_session_requests = serialize_requests
4042

4143
def create_streams(self):
4244
raise NotImplementedError
@@ -148,3 +150,70 @@ async def test_call_tool_rejects_non_object_arguments_before_remote_call():
148150
await server.call_tool("tool", cast(dict[str, object] | None, ["bad"]))
149151

150152
assert session.call_tool_attempts == 0
153+
154+
155+
class ConcurrentCancellationSession:
156+
def __init__(self):
157+
self._slow_task: asyncio.Task[CallToolResult] | None = None
158+
self._slow_started = asyncio.Event()
159+
160+
async def call_tool(self, tool_name, arguments, meta=None):
161+
if tool_name == "slow":
162+
self._slow_task = cast(asyncio.Task[CallToolResult], asyncio.current_task())
163+
self._slow_started.set()
164+
await asyncio.sleep(0.1)
165+
return CallToolResult(content=[])
166+
167+
await self._slow_started.wait()
168+
assert self._slow_task is not None
169+
self._slow_task.cancel()
170+
raise RuntimeError("synthetic request failure")
171+
172+
async def list_tools(self):
173+
return ListToolsResult(tools=[MCPTool(name="tool", inputSchema={})])
174+
175+
async def list_prompts(self):
176+
await self._slow_started.wait()
177+
assert self._slow_task is not None
178+
self._slow_task.cancel()
179+
raise RuntimeError("synthetic request failure")
180+
181+
async def get_prompt(self, name, arguments=None):
182+
await self._slow_started.wait()
183+
assert self._slow_task is not None
184+
self._slow_task.cancel()
185+
raise RuntimeError("synthetic request failure")
186+
187+
188+
@pytest.mark.asyncio
189+
async def test_serialized_session_requests_prevent_sibling_cancellation():
190+
session = ConcurrentCancellationSession()
191+
server = DummyServer(session=cast(DummySession, session), retries=0, serialize_requests=True)
192+
193+
results = await asyncio.gather(
194+
server.call_tool("slow", None),
195+
server.call_tool("fail", None),
196+
return_exceptions=True,
197+
)
198+
199+
assert isinstance(results[0], CallToolResult)
200+
assert isinstance(results[1], RuntimeError)
201+
202+
203+
@pytest.mark.asyncio
204+
@pytest.mark.parametrize("prompt_method", ["list_prompts", "get_prompt"])
205+
async def test_serialized_prompt_requests_prevent_tool_cancellation(prompt_method: str):
206+
session = ConcurrentCancellationSession()
207+
server = DummyServer(session=cast(DummySession, session), retries=0, serialize_requests=True)
208+
209+
prompt_request = (
210+
server.list_prompts() if prompt_method == "list_prompts" else server.get_prompt("prompt")
211+
)
212+
results = await asyncio.gather(
213+
server.call_tool("slow", None),
214+
prompt_request,
215+
return_exceptions=True,
216+
)
217+
218+
assert isinstance(results[0], CallToolResult)
219+
assert isinstance(results[1], RuntimeError)

0 commit comments

Comments
 (0)