Skip to content

Commit 710449c

Browse files
fix: handle inner MCP tool cancellations as tool errors (#2681)
1 parent 8c5c650 commit 710449c

File tree

2 files changed

+218
-5
lines changed

2 files changed

+218
-5
lines changed

src/agents/mcp/util.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import copy
45
import functools
56
import inspect
@@ -359,10 +360,28 @@ async def invoke_mcp_tool(
359360
try:
360361
resolved_meta = await cls._resolve_meta(server, context, tool.name, json_data)
361362
merged_meta = cls._merge_mcp_meta(resolved_meta, meta)
362-
if merged_meta is None:
363-
result = await server.call_tool(tool.name, json_data)
364-
else:
365-
result = await server.call_tool(tool.name, json_data, meta=merged_meta)
363+
call_task = asyncio.create_task(
364+
server.call_tool(tool.name, json_data)
365+
if merged_meta is None
366+
else server.call_tool(tool.name, json_data, meta=merged_meta)
367+
)
368+
try:
369+
done, _ = await asyncio.wait({call_task}, return_when=asyncio.FIRST_COMPLETED)
370+
finished_task = done.pop()
371+
if finished_task.cancelled():
372+
raise UserError(
373+
f"Failed to call tool '{tool.name}' on MCP server '{server.name}': "
374+
"tool execution was cancelled."
375+
)
376+
result = finished_task.result()
377+
except asyncio.CancelledError:
378+
if not call_task.done():
379+
call_task.cancel()
380+
try:
381+
await call_task
382+
except (asyncio.CancelledError, Exception):
383+
pass
384+
raise
366385
except UserError:
367386
# Re-raise UserError as-is (it already has a good message)
368387
raise

tests/mcp/test_mcp_util.py

Lines changed: 195 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import dataclasses
23
import json
34
import logging
@@ -9,7 +10,7 @@
910
from pydantic import BaseModel, TypeAdapter
1011

1112
from agents import Agent, FunctionTool, RunContextWrapper, default_tool_error_function
12-
from agents.exceptions import AgentsException, ModelBehaviorError
13+
from agents.exceptions import AgentsException, ModelBehaviorError, UserError
1314
from agents.mcp import MCPServer, MCPUtil
1415
from agents.tool_context import ToolContext
1516

@@ -175,6 +176,46 @@ async def call_tool(
175176
raise Exception("Crash!")
176177

177178

179+
class CancelledFakeMCPServer(FakeMCPServer):
180+
async def call_tool(
181+
self,
182+
tool_name: str,
183+
arguments: dict[str, Any] | None,
184+
meta: dict[str, Any] | None = None,
185+
):
186+
raise asyncio.CancelledError("synthetic mcp cancel")
187+
188+
189+
class SlowFakeMCPServer(FakeMCPServer):
190+
async def call_tool(
191+
self,
192+
tool_name: str,
193+
arguments: dict[str, Any] | None,
194+
meta: dict[str, Any] | None = None,
195+
):
196+
await asyncio.sleep(60)
197+
return await super().call_tool(tool_name, arguments, meta=meta)
198+
199+
200+
class CleanupOnCancelFakeMCPServer(FakeMCPServer):
201+
def __init__(self, cleanup_finished: asyncio.Event):
202+
super().__init__()
203+
self.cleanup_finished = cleanup_finished
204+
205+
async def call_tool(
206+
self,
207+
tool_name: str,
208+
arguments: dict[str, Any] | None,
209+
meta: dict[str, Any] | None = None,
210+
):
211+
try:
212+
await asyncio.sleep(60)
213+
except asyncio.CancelledError:
214+
await asyncio.sleep(0.05)
215+
self.cleanup_finished.set()
216+
raise
217+
218+
178219
@pytest.mark.asyncio
179220
async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixture):
180221
caplog.set_level(logging.DEBUG)
@@ -192,6 +233,159 @@ async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixtur
192233
assert "Error invoking MCP tool test_tool_1" in caplog.text
193234

194235

236+
@pytest.mark.asyncio
237+
async def test_mcp_tool_inner_cancellation_becomes_tool_error():
238+
server = CancelledFakeMCPServer()
239+
server.add_tool("cancel_tool", {})
240+
241+
ctx = RunContextWrapper(context=None)
242+
tool = MCPTool(name="cancel_tool", inputSchema={})
243+
244+
with pytest.raises(UserError, match="tool execution was cancelled"):
245+
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "{}")
246+
247+
agent = Agent(name="test-agent")
248+
function_tool = MCPUtil.to_function_tool(
249+
tool, server, convert_schemas_to_strict=False, agent=agent
250+
)
251+
tool_context = ToolContext(
252+
context=None,
253+
tool_name="cancel_tool",
254+
tool_call_id="test_call_cancelled",
255+
tool_arguments="{}",
256+
)
257+
258+
result = await function_tool.on_invoke_tool(tool_context, "{}")
259+
assert isinstance(result, str)
260+
assert "tool execution was cancelled" in result
261+
262+
263+
@pytest.mark.asyncio
264+
async def test_mcp_tool_inner_cancellation_still_becomes_tool_error_with_prior_cancel_state():
265+
current_task = asyncio.current_task()
266+
assert current_task is not None
267+
268+
current_task.cancel()
269+
with pytest.raises(asyncio.CancelledError):
270+
await asyncio.sleep(0)
271+
272+
server = CancelledFakeMCPServer()
273+
server.add_tool("cancel_tool", {})
274+
275+
ctx = RunContextWrapper(context=None)
276+
tool = MCPTool(name="cancel_tool", inputSchema={})
277+
278+
with pytest.raises(UserError, match="tool execution was cancelled"):
279+
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "{}")
280+
281+
282+
@pytest.mark.asyncio
283+
async def test_mcp_tool_outer_cancellation_still_propagates():
284+
server = SlowFakeMCPServer()
285+
server.add_tool("slow_tool", {})
286+
287+
ctx = RunContextWrapper(context=None)
288+
tool = MCPTool(name="slow_tool", inputSchema={})
289+
290+
task = asyncio.create_task(MCPUtil.invoke_mcp_tool(server, tool, ctx, "{}"))
291+
await asyncio.sleep(0.05)
292+
task.cancel()
293+
294+
with pytest.raises(asyncio.CancelledError):
295+
await task
296+
297+
298+
@pytest.mark.asyncio
299+
async def test_mcp_tool_outer_cancellation_after_inner_completion_still_propagates(
300+
monkeypatch: pytest.MonkeyPatch,
301+
):
302+
server = FakeMCPServer()
303+
server.add_tool("fast_tool", {})
304+
305+
ctx = RunContextWrapper(context=None)
306+
tool = MCPTool(name="fast_tool", inputSchema={})
307+
308+
async def fake_wait(tasks, *, return_when):
309+
del return_when
310+
(task,) = tuple(tasks)
311+
await task
312+
raise asyncio.CancelledError("synthetic outer cancellation")
313+
314+
monkeypatch.setattr(asyncio, "wait", fake_wait)
315+
316+
with pytest.raises(asyncio.CancelledError):
317+
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "{}")
318+
319+
320+
@pytest.mark.asyncio
321+
async def test_mcp_tool_outer_cancellation_after_inner_exception_still_propagates(
322+
monkeypatch: pytest.MonkeyPatch,
323+
):
324+
server = CrashingFakeMCPServer()
325+
server.add_tool("boom_tool", {})
326+
327+
ctx = RunContextWrapper(context=None)
328+
tool = MCPTool(name="boom_tool", inputSchema={})
329+
330+
async def fake_wait(tasks, *, return_when):
331+
del return_when
332+
(task,) = tuple(tasks)
333+
try:
334+
await task
335+
except Exception:
336+
pass
337+
raise asyncio.CancelledError("synthetic outer cancellation")
338+
339+
monkeypatch.setattr(asyncio, "wait", fake_wait)
340+
341+
with pytest.raises(asyncio.CancelledError):
342+
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "{}")
343+
344+
345+
@pytest.mark.asyncio
346+
async def test_mcp_tool_outer_cancellation_after_inner_cancellation_still_propagates(
347+
monkeypatch: pytest.MonkeyPatch,
348+
):
349+
server = SlowFakeMCPServer()
350+
server.add_tool("slow_tool", {})
351+
352+
ctx = RunContextWrapper(context=None)
353+
tool = MCPTool(name="slow_tool", inputSchema={})
354+
355+
async def fake_wait(tasks, *, return_when):
356+
del return_when
357+
(task,) = tuple(tasks)
358+
task.cancel()
359+
with pytest.raises(asyncio.CancelledError):
360+
await task
361+
362+
raise asyncio.CancelledError("synthetic combined cancellation")
363+
364+
monkeypatch.setattr(asyncio, "wait", fake_wait)
365+
366+
with pytest.raises(asyncio.CancelledError):
367+
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "{}")
368+
369+
370+
@pytest.mark.asyncio
371+
async def test_mcp_tool_outer_cancellation_waits_for_inner_cleanup():
372+
cleanup_finished = asyncio.Event()
373+
server = CleanupOnCancelFakeMCPServer(cleanup_finished)
374+
server.add_tool("slow_tool", {})
375+
376+
ctx = RunContextWrapper(context=None)
377+
tool = MCPTool(name="slow_tool", inputSchema={})
378+
379+
task = asyncio.create_task(MCPUtil.invoke_mcp_tool(server, tool, ctx, "{}"))
380+
await asyncio.sleep(0.05)
381+
task.cancel()
382+
383+
with pytest.raises(asyncio.CancelledError):
384+
await task
385+
386+
assert cleanup_finished.is_set()
387+
388+
195389
@pytest.mark.asyncio
196390
async def test_mcp_invocation_mcp_error_reraises(caplog: pytest.LogCaptureFixture):
197391
"""Test that McpError from server.call_tool is re-raised so the FunctionTool failure

0 commit comments

Comments
 (0)