1+ import asyncio
12import dataclasses
23import json
34import logging
910from pydantic import BaseModel , TypeAdapter
1011
1112from agents import Agent , FunctionTool , RunContextWrapper , default_tool_error_function
12- from agents .exceptions import AgentsException , ModelBehaviorError
13+ from agents .exceptions import AgentsException , ModelBehaviorError , UserError
1314from agents .mcp import MCPServer , MCPUtil
1415from agents .tool_context import ToolContext
1516
@@ -175,6 +176,46 @@ async def call_tool(
175176 raise Exception ("Crash!" )
176177
177178
179+ class CancelledFakeMCPServer (FakeMCPServer ):
180+ async def call_tool (
181+ self ,
182+ tool_name : str ,
183+ arguments : dict [str , Any ] | None ,
184+ meta : dict [str , Any ] | None = None ,
185+ ):
186+ raise asyncio .CancelledError ("synthetic mcp cancel" )
187+
188+
189+ class SlowFakeMCPServer (FakeMCPServer ):
190+ async def call_tool (
191+ self ,
192+ tool_name : str ,
193+ arguments : dict [str , Any ] | None ,
194+ meta : dict [str , Any ] | None = None ,
195+ ):
196+ await asyncio .sleep (60 )
197+ return await super ().call_tool (tool_name , arguments , meta = meta )
198+
199+
200+ class CleanupOnCancelFakeMCPServer (FakeMCPServer ):
201+ def __init__ (self , cleanup_finished : asyncio .Event ):
202+ super ().__init__ ()
203+ self .cleanup_finished = cleanup_finished
204+
205+ async def call_tool (
206+ self ,
207+ tool_name : str ,
208+ arguments : dict [str , Any ] | None ,
209+ meta : dict [str , Any ] | None = None ,
210+ ):
211+ try :
212+ await asyncio .sleep (60 )
213+ except asyncio .CancelledError :
214+ await asyncio .sleep (0.05 )
215+ self .cleanup_finished .set ()
216+ raise
217+
218+
178219@pytest .mark .asyncio
179220async def test_mcp_invocation_crash_causes_error (caplog : pytest .LogCaptureFixture ):
180221 caplog .set_level (logging .DEBUG )
@@ -192,6 +233,159 @@ async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixtur
192233 assert "Error invoking MCP tool test_tool_1" in caplog .text
193234
194235
236+ @pytest .mark .asyncio
237+ async def test_mcp_tool_inner_cancellation_becomes_tool_error ():
238+ server = CancelledFakeMCPServer ()
239+ server .add_tool ("cancel_tool" , {})
240+
241+ ctx = RunContextWrapper (context = None )
242+ tool = MCPTool (name = "cancel_tool" , inputSchema = {})
243+
244+ with pytest .raises (UserError , match = "tool execution was cancelled" ):
245+ await MCPUtil .invoke_mcp_tool (server , tool , ctx , "{}" )
246+
247+ agent = Agent (name = "test-agent" )
248+ function_tool = MCPUtil .to_function_tool (
249+ tool , server , convert_schemas_to_strict = False , agent = agent
250+ )
251+ tool_context = ToolContext (
252+ context = None ,
253+ tool_name = "cancel_tool" ,
254+ tool_call_id = "test_call_cancelled" ,
255+ tool_arguments = "{}" ,
256+ )
257+
258+ result = await function_tool .on_invoke_tool (tool_context , "{}" )
259+ assert isinstance (result , str )
260+ assert "tool execution was cancelled" in result
261+
262+
263+ @pytest .mark .asyncio
264+ async def test_mcp_tool_inner_cancellation_still_becomes_tool_error_with_prior_cancel_state ():
265+ current_task = asyncio .current_task ()
266+ assert current_task is not None
267+
268+ current_task .cancel ()
269+ with pytest .raises (asyncio .CancelledError ):
270+ await asyncio .sleep (0 )
271+
272+ server = CancelledFakeMCPServer ()
273+ server .add_tool ("cancel_tool" , {})
274+
275+ ctx = RunContextWrapper (context = None )
276+ tool = MCPTool (name = "cancel_tool" , inputSchema = {})
277+
278+ with pytest .raises (UserError , match = "tool execution was cancelled" ):
279+ await MCPUtil .invoke_mcp_tool (server , tool , ctx , "{}" )
280+
281+
282+ @pytest .mark .asyncio
283+ async def test_mcp_tool_outer_cancellation_still_propagates ():
284+ server = SlowFakeMCPServer ()
285+ server .add_tool ("slow_tool" , {})
286+
287+ ctx = RunContextWrapper (context = None )
288+ tool = MCPTool (name = "slow_tool" , inputSchema = {})
289+
290+ task = asyncio .create_task (MCPUtil .invoke_mcp_tool (server , tool , ctx , "{}" ))
291+ await asyncio .sleep (0.05 )
292+ task .cancel ()
293+
294+ with pytest .raises (asyncio .CancelledError ):
295+ await task
296+
297+
298+ @pytest .mark .asyncio
299+ async def test_mcp_tool_outer_cancellation_after_inner_completion_still_propagates (
300+ monkeypatch : pytest .MonkeyPatch ,
301+ ):
302+ server = FakeMCPServer ()
303+ server .add_tool ("fast_tool" , {})
304+
305+ ctx = RunContextWrapper (context = None )
306+ tool = MCPTool (name = "fast_tool" , inputSchema = {})
307+
308+ async def fake_wait (tasks , * , return_when ):
309+ del return_when
310+ (task ,) = tuple (tasks )
311+ await task
312+ raise asyncio .CancelledError ("synthetic outer cancellation" )
313+
314+ monkeypatch .setattr (asyncio , "wait" , fake_wait )
315+
316+ with pytest .raises (asyncio .CancelledError ):
317+ await MCPUtil .invoke_mcp_tool (server , tool , ctx , "{}" )
318+
319+
320+ @pytest .mark .asyncio
321+ async def test_mcp_tool_outer_cancellation_after_inner_exception_still_propagates (
322+ monkeypatch : pytest .MonkeyPatch ,
323+ ):
324+ server = CrashingFakeMCPServer ()
325+ server .add_tool ("boom_tool" , {})
326+
327+ ctx = RunContextWrapper (context = None )
328+ tool = MCPTool (name = "boom_tool" , inputSchema = {})
329+
330+ async def fake_wait (tasks , * , return_when ):
331+ del return_when
332+ (task ,) = tuple (tasks )
333+ try :
334+ await task
335+ except Exception :
336+ pass
337+ raise asyncio .CancelledError ("synthetic outer cancellation" )
338+
339+ monkeypatch .setattr (asyncio , "wait" , fake_wait )
340+
341+ with pytest .raises (asyncio .CancelledError ):
342+ await MCPUtil .invoke_mcp_tool (server , tool , ctx , "{}" )
343+
344+
345+ @pytest .mark .asyncio
346+ async def test_mcp_tool_outer_cancellation_after_inner_cancellation_still_propagates (
347+ monkeypatch : pytest .MonkeyPatch ,
348+ ):
349+ server = SlowFakeMCPServer ()
350+ server .add_tool ("slow_tool" , {})
351+
352+ ctx = RunContextWrapper (context = None )
353+ tool = MCPTool (name = "slow_tool" , inputSchema = {})
354+
355+ async def fake_wait (tasks , * , return_when ):
356+ del return_when
357+ (task ,) = tuple (tasks )
358+ task .cancel ()
359+ with pytest .raises (asyncio .CancelledError ):
360+ await task
361+
362+ raise asyncio .CancelledError ("synthetic combined cancellation" )
363+
364+ monkeypatch .setattr (asyncio , "wait" , fake_wait )
365+
366+ with pytest .raises (asyncio .CancelledError ):
367+ await MCPUtil .invoke_mcp_tool (server , tool , ctx , "{}" )
368+
369+
370+ @pytest .mark .asyncio
371+ async def test_mcp_tool_outer_cancellation_waits_for_inner_cleanup ():
372+ cleanup_finished = asyncio .Event ()
373+ server = CleanupOnCancelFakeMCPServer (cleanup_finished )
374+ server .add_tool ("slow_tool" , {})
375+
376+ ctx = RunContextWrapper (context = None )
377+ tool = MCPTool (name = "slow_tool" , inputSchema = {})
378+
379+ task = asyncio .create_task (MCPUtil .invoke_mcp_tool (server , tool , ctx , "{}" ))
380+ await asyncio .sleep (0.05 )
381+ task .cancel ()
382+
383+ with pytest .raises (asyncio .CancelledError ):
384+ await task
385+
386+ assert cleanup_finished .is_set ()
387+
388+
195389@pytest .mark .asyncio
196390async def test_mcp_invocation_mcp_error_reraises (caplog : pytest .LogCaptureFixture ):
197391 """Test that McpError from server.call_tool is re-raised so the FunctionTool failure
0 commit comments