|
| 1 | +import asyncio |
1 | 2 | from typing import cast |
2 | 3 |
|
3 | 4 | import pytest |
@@ -29,14 +30,15 @@ async def list_tools(self): |
29 | 30 |
|
30 | 31 |
|
31 | 32 | class DummyServer(_MCPServerWithClientSession): |
32 | | - def __init__(self, session: DummySession, retries: int): |
| 33 | + def __init__(self, session: DummySession, retries: int, *, serialize_requests: bool = False): |
33 | 34 | super().__init__( |
34 | 35 | cache_tools_list=False, |
35 | 36 | client_session_timeout_seconds=None, |
36 | 37 | max_retry_attempts=retries, |
37 | 38 | retry_backoff_seconds_base=0, |
38 | 39 | ) |
39 | 40 | self.session = cast(ClientSession, session) |
| 41 | + self._serialize_session_requests = serialize_requests |
40 | 42 |
|
41 | 43 | def create_streams(self): |
42 | 44 | raise NotImplementedError |
@@ -148,3 +150,70 @@ async def test_call_tool_rejects_non_object_arguments_before_remote_call(): |
148 | 150 | await server.call_tool("tool", cast(dict[str, object] | None, ["bad"])) |
149 | 151 |
|
150 | 152 | assert session.call_tool_attempts == 0 |
| 153 | + |
| 154 | + |
| 155 | +class ConcurrentCancellationSession: |
| 156 | + def __init__(self): |
| 157 | + self._slow_task: asyncio.Task[CallToolResult] | None = None |
| 158 | + self._slow_started = asyncio.Event() |
| 159 | + |
| 160 | + async def call_tool(self, tool_name, arguments, meta=None): |
| 161 | + if tool_name == "slow": |
| 162 | + self._slow_task = cast(asyncio.Task[CallToolResult], asyncio.current_task()) |
| 163 | + self._slow_started.set() |
| 164 | + await asyncio.sleep(0.1) |
| 165 | + return CallToolResult(content=[]) |
| 166 | + |
| 167 | + await self._slow_started.wait() |
| 168 | + assert self._slow_task is not None |
| 169 | + self._slow_task.cancel() |
| 170 | + raise RuntimeError("synthetic request failure") |
| 171 | + |
| 172 | + async def list_tools(self): |
| 173 | + return ListToolsResult(tools=[MCPTool(name="tool", inputSchema={})]) |
| 174 | + |
| 175 | + async def list_prompts(self): |
| 176 | + await self._slow_started.wait() |
| 177 | + assert self._slow_task is not None |
| 178 | + self._slow_task.cancel() |
| 179 | + raise RuntimeError("synthetic request failure") |
| 180 | + |
| 181 | + async def get_prompt(self, name, arguments=None): |
| 182 | + await self._slow_started.wait() |
| 183 | + assert self._slow_task is not None |
| 184 | + self._slow_task.cancel() |
| 185 | + raise RuntimeError("synthetic request failure") |
| 186 | + |
| 187 | + |
| 188 | +@pytest.mark.asyncio |
| 189 | +async def test_serialized_session_requests_prevent_sibling_cancellation(): |
| 190 | + session = ConcurrentCancellationSession() |
| 191 | + server = DummyServer(session=cast(DummySession, session), retries=0, serialize_requests=True) |
| 192 | + |
| 193 | + results = await asyncio.gather( |
| 194 | + server.call_tool("slow", None), |
| 195 | + server.call_tool("fail", None), |
| 196 | + return_exceptions=True, |
| 197 | + ) |
| 198 | + |
| 199 | + assert isinstance(results[0], CallToolResult) |
| 200 | + assert isinstance(results[1], RuntimeError) |
| 201 | + |
| 202 | + |
| 203 | +@pytest.mark.asyncio |
| 204 | +@pytest.mark.parametrize("prompt_method", ["list_prompts", "get_prompt"]) |
| 205 | +async def test_serialized_prompt_requests_prevent_tool_cancellation(prompt_method: str): |
| 206 | + session = ConcurrentCancellationSession() |
| 207 | + server = DummyServer(session=cast(DummySession, session), retries=0, serialize_requests=True) |
| 208 | + |
| 209 | + prompt_request = ( |
| 210 | + server.list_prompts() if prompt_method == "list_prompts" else server.get_prompt("prompt") |
| 211 | + ) |
| 212 | + results = await asyncio.gather( |
| 213 | + server.call_tool("slow", None), |
| 214 | + prompt_request, |
| 215 | + return_exceptions=True, |
| 216 | + ) |
| 217 | + |
| 218 | + assert isinstance(results[0], CallToolResult) |
| 219 | + assert isinstance(results[1], RuntimeError) |
0 commit comments