@@ -378,3 +378,140 @@ async def test_inactivity_timeout():
378378 assert len (collected_turns ) == 0 , "No transcripts expected, but we got something?"
379379
380380 await session .close ()
381+
382+
383+ @pytest .mark .asyncio
384+ async def test_cleanup_tasks_cancels_and_awaits_all_tasks ():
385+ """
386+ Test that _cleanup_tasks() properly cancels and awaits all pending tasks.
387+ This ensures proper resource cleanup and prevents unhandled task exceptions.
388+ """
389+ mock_ws = create_mock_websocket (
390+ [
391+ json .dumps ({"type" : "transcription_session.created" }),
392+ json .dumps ({"type" : "transcription_session.updated" }),
393+ ]
394+ )
395+
396+ with patch ("websockets.connect" , return_value = mock_ws ):
397+ audio_input = await FakeStreamedAudioInput .get (count = 2 )
398+ stt_settings = STTModelSettings ()
399+
400+ session = OpenAISTTTranscriptionSession (
401+ input = audio_input ,
402+ client = AsyncMock (api_key = "FAKE_KEY" ),
403+ model = "whisper-1" ,
404+ settings = stt_settings ,
405+ trace_include_sensitive_data = False ,
406+ trace_include_sensitive_audio_data = False ,
407+ )
408+
409+ # Create some tasks to simulate active background tasks
410+ async def long_running_task ():
411+ try :
412+ await asyncio .sleep (10 )
413+ except asyncio .CancelledError :
414+ # Expected when cancelled
415+ raise
416+
417+ session ._listener_task = asyncio .create_task (long_running_task ())
418+ session ._process_events_task = asyncio .create_task (long_running_task ())
419+ session ._stream_audio_task = asyncio .create_task (long_running_task ())
420+ session ._connection_task = asyncio .create_task (long_running_task ())
421+
422+ # Verify tasks are running
423+ assert not session ._listener_task .done ()
424+ assert not session ._process_events_task .done ()
425+ assert not session ._stream_audio_task .done ()
426+ assert not session ._connection_task .done ()
427+
428+ # Call cleanup_tasks
429+ await session ._cleanup_tasks ()
430+
431+ # Verify all tasks were cancelled and completed
432+ assert session ._listener_task .cancelled ()
433+ assert session ._process_events_task .cancelled ()
434+ assert session ._stream_audio_task .cancelled ()
435+ assert session ._connection_task .cancelled ()
436+
437+
438+ @pytest .mark .asyncio
439+ async def test_cleanup_tasks_handles_exceptions ():
440+ """
441+ Test that _cleanup_tasks() properly handles exceptions from cancelled tasks
442+ without raising them (using return_exceptions=True).
443+ """
444+ mock_ws = create_mock_websocket (
445+ [
446+ json .dumps ({"type" : "transcription_session.created" }),
447+ json .dumps ({"type" : "transcription_session.updated" }),
448+ ]
449+ )
450+
451+ with patch ("websockets.connect" , return_value = mock_ws ):
452+ audio_input = await FakeStreamedAudioInput .get (count = 2 )
453+ stt_settings = STTModelSettings ()
454+
455+ session = OpenAISTTTranscriptionSession (
456+ input = audio_input ,
457+ client = AsyncMock (api_key = "FAKE_KEY" ),
458+ model = "whisper-1" ,
459+ settings = stt_settings ,
460+ trace_include_sensitive_data = False ,
461+ trace_include_sensitive_audio_data = False ,
462+ )
463+
464+ # Create tasks that raise exceptions when cancelled
465+ async def task_with_exception ():
466+ try :
467+ await asyncio .sleep (10 )
468+ except asyncio .CancelledError :
469+ raise RuntimeError ("Task exception during cancellation" )
470+
471+ session ._listener_task = asyncio .create_task (task_with_exception ())
472+ session ._process_events_task = asyncio .create_task (task_with_exception ())
473+
474+ # cleanup_tasks should not raise despite the exceptions
475+ await session ._cleanup_tasks ()
476+
477+ # Tasks should be done (cancelled or exception raised)
478+ assert session ._listener_task .done ()
479+ assert session ._process_events_task .done ()
480+
481+
482+ @pytest .mark .asyncio
483+ async def test_close_calls_cleanup_tasks ():
484+ """
485+ Test that close() properly calls _cleanup_tasks() to clean up background tasks.
486+ """
487+ mock_ws = create_mock_websocket (
488+ [
489+ json .dumps ({"type" : "transcription_session.created" }),
490+ json .dumps ({"type" : "transcription_session.updated" }),
491+ ]
492+ )
493+
494+ with patch ("websockets.connect" , return_value = mock_ws ):
495+ audio_input = await FakeStreamedAudioInput .get (count = 2 )
496+ stt_settings = STTModelSettings ()
497+
498+ session = OpenAISTTTranscriptionSession (
499+ input = audio_input ,
500+ client = AsyncMock (api_key = "FAKE_KEY" ),
501+ model = "whisper-1" ,
502+ settings = stt_settings ,
503+ trace_include_sensitive_data = False ,
504+ trace_include_sensitive_audio_data = False ,
505+ )
506+
507+ # Create a task
508+ async def long_running_task ():
509+ await asyncio .sleep (10 )
510+
511+ session ._listener_task = asyncio .create_task (long_running_task ())
512+
513+ # close() should cancel and await the task
514+ await session .close ()
515+
516+ # Task should be cancelled
517+ assert session ._listener_task .cancelled ()
0 commit comments