Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/agents/realtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,16 +746,32 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
)
)

def _cleanup_guardrail_tasks(self) -> None:
async def _cleanup_guardrail_tasks(self) -> None:
"""Cancel all pending guardrail tasks and wait for them to complete.

This ensures that any exceptions raised by the tasks are properly handled
and prevents warnings about unhandled task exceptions.
"""
# Collect real asyncio.Task objects that need to be awaited
real_tasks = []

for task in self._guardrail_tasks:
if not task.done():
task.cancel()
# Only await real asyncio.Task objects (not mocks in tests)
if isinstance(task, asyncio.Task):
real_tasks.append(task)

# Wait for all real tasks to complete and collect any exceptions
if real_tasks:
await asyncio.gather(*real_tasks, return_exceptions=True)

self._guardrail_tasks.clear()

async def _cleanup(self) -> None:
"""Clean up all resources and mark session as closed."""
# Cancel and cleanup guardrail tasks
self._cleanup_guardrail_tasks()
await self._cleanup_guardrail_tasks()

# Remove ourselves as a listener
self._model.remove_listener(self)
Expand Down
246 changes: 246 additions & 0 deletions tests/realtime/test_guardrail_cleanup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
"""Test guardrail task cleanup to ensure proper exception handling.

This test verifies the fix for the bug where _cleanup_guardrail_tasks() was not
properly awaiting cancelled tasks, which could lead to unhandled task exceptions
and potential memory leaks.
"""

import asyncio
from unittest.mock import AsyncMock, Mock, PropertyMock

import pytest

from agents.guardrail import GuardrailFunctionOutput, OutputGuardrail
from agents.realtime import RealtimeSession
from agents.realtime.agent import RealtimeAgent
from agents.realtime.config import RealtimeRunConfig
from agents.realtime.model import RealtimeModel
from agents.realtime.model_events import RealtimeModelTranscriptDeltaEvent


class MockRealtimeModel(RealtimeModel):
"""Mock realtime model for testing."""

def __init__(self):
super().__init__()
self.listeners = []
self.connect_called = False
self.close_called = False
self.sent_events = []
self.sent_messages = []
self.sent_audio = []
self.sent_tool_outputs = []
self.interrupts_called = 0

async def connect(self, options=None):
self.connect_called = True

def add_listener(self, listener):
self.listeners.append(listener)

def remove_listener(self, listener):
if listener in self.listeners:
self.listeners.remove(listener)

async def send_event(self, event):
from agents.realtime.model_inputs import (
RealtimeModelSendAudio,
RealtimeModelSendInterrupt,
RealtimeModelSendToolOutput,
RealtimeModelSendUserInput,
)

self.sent_events.append(event)

# Update legacy tracking for compatibility
if isinstance(event, RealtimeModelSendUserInput):
self.sent_messages.append(event.user_input)
elif isinstance(event, RealtimeModelSendAudio):
self.sent_audio.append((event.audio, event.commit))
elif isinstance(event, RealtimeModelSendToolOutput):
self.sent_tool_outputs.append((event.tool_call, event.output, event.start_response))
elif isinstance(event, RealtimeModelSendInterrupt):
self.interrupts_called += 1

async def close(self):
self.close_called = True


@pytest.fixture
def mock_model():
return MockRealtimeModel()


@pytest.fixture
def mock_agent():
agent = Mock(spec=RealtimeAgent)
agent.name = "test_agent"
agent.get_all_tools = AsyncMock(return_value=[])
type(agent).handoffs = PropertyMock(return_value=[])
type(agent).output_guardrails = PropertyMock(return_value=[])
return agent


@pytest.mark.asyncio
async def test_guardrail_task_cleanup_awaits_cancelled_tasks(mock_model, mock_agent):
"""Test that cleanup properly awaits cancelled guardrail tasks.

This test verifies that when guardrail tasks are cancelled during cleanup,
the cleanup method properly awaits them to completion using asyncio.gather()
with return_exceptions=True. This ensures:
1. No warnings about unhandled task exceptions
2. Proper resource cleanup
3. No memory leaks from abandoned tasks
"""

# Create a guardrail that runs a long async operation
task_started = asyncio.Event()
task_cancelled = asyncio.Event()

async def slow_guardrail_func(context, agent, output):
"""A guardrail that takes time to execute."""
task_started.set()
try:
# Simulate a long-running operation
await asyncio.sleep(10)
return GuardrailFunctionOutput(output_info={}, tripwire_triggered=False)
except asyncio.CancelledError:
task_cancelled.set()
raise

guardrail = OutputGuardrail(guardrail_function=slow_guardrail_func, name="slow_guardrail")

run_config: RealtimeRunConfig = {
"output_guardrails": [guardrail],
"guardrails_settings": {"debounce_text_length": 5},
}

session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)

# Trigger a guardrail by sending a transcript delta
transcript_event = RealtimeModelTranscriptDeltaEvent(
item_id="item_1", delta="hello world", response_id="resp_1"
)

await session.on_event(transcript_event)

# Wait for the guardrail task to start
await asyncio.wait_for(task_started.wait(), timeout=1.0)

# Verify a guardrail task was created
assert len(session._guardrail_tasks) == 1
task = list(session._guardrail_tasks)[0]
assert not task.done()

# Now cleanup the session - this should cancel and await the task
await session._cleanup_guardrail_tasks()

# Verify the task was cancelled and properly awaited
assert task_cancelled.is_set(), "Task should have received CancelledError"
assert len(session._guardrail_tasks) == 0, "Tasks list should be cleared"

# No warnings should be raised about unhandled task exceptions


@pytest.mark.asyncio
async def test_guardrail_task_cleanup_with_exception(mock_model, mock_agent):
"""Test that cleanup handles guardrail tasks that raise exceptions.

This test verifies that if a guardrail task raises an exception (not just
CancelledError), the cleanup method still completes successfully and doesn't
propagate the exception, thanks to return_exceptions=True.
"""

task_started = asyncio.Event()
exception_raised = asyncio.Event()

async def failing_guardrail_func(context, agent, output):
"""A guardrail that raises an exception."""
task_started.set()
try:
await asyncio.sleep(10)
return GuardrailFunctionOutput(output_info={}, tripwire_triggered=False)
except asyncio.CancelledError as e:
exception_raised.set()
# Simulate an error during cleanup
raise RuntimeError("Cleanup error") from e

guardrail = OutputGuardrail(
guardrail_function=failing_guardrail_func, name="failing_guardrail"
)

run_config: RealtimeRunConfig = {
"output_guardrails": [guardrail],
"guardrails_settings": {"debounce_text_length": 5},
}

session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)

# Trigger a guardrail
transcript_event = RealtimeModelTranscriptDeltaEvent(
item_id="item_1", delta="hello world", response_id="resp_1"
)

await session.on_event(transcript_event)

# Wait for the guardrail task to start
await asyncio.wait_for(task_started.wait(), timeout=1.0)

# Cleanup should not raise the RuntimeError due to return_exceptions=True
await session._cleanup_guardrail_tasks()

# Verify cleanup completed successfully
assert exception_raised.is_set()
assert len(session._guardrail_tasks) == 0


@pytest.mark.asyncio
async def test_guardrail_task_cleanup_with_multiple_tasks(mock_model, mock_agent):
"""Test cleanup with multiple pending guardrail tasks.

This test verifies that cleanup properly handles multiple concurrent guardrail
tasks by triggering guardrails multiple times, then cancelling and awaiting all of them.
"""

tasks_started = asyncio.Event()
tasks_cancelled = 0

async def slow_guardrail_func(context, agent, output):
nonlocal tasks_cancelled
tasks_started.set()
try:
await asyncio.sleep(10)
return GuardrailFunctionOutput(output_info={}, tripwire_triggered=False)
except asyncio.CancelledError:
tasks_cancelled += 1
raise

guardrail = OutputGuardrail(guardrail_function=slow_guardrail_func, name="slow_guardrail")

run_config: RealtimeRunConfig = {
"output_guardrails": [guardrail],
"guardrails_settings": {"debounce_text_length": 5},
}

session = RealtimeSession(mock_model, mock_agent, None, run_config=run_config)

# Trigger guardrails multiple times to create multiple tasks
for i in range(3):
transcript_event = RealtimeModelTranscriptDeltaEvent(
item_id=f"item_{i}", delta="hello world", response_id=f"resp_{i}"
)
await session.on_event(transcript_event)

# Wait for at least one task to start
await asyncio.wait_for(tasks_started.wait(), timeout=1.0)

# Should have at least one guardrail task
initial_task_count = len(session._guardrail_tasks)
assert initial_task_count >= 1, "At least one guardrail task should exist"

# Cleanup should cancel and await all tasks
await session._cleanup_guardrail_tasks()

# Verify all tasks were cancelled and cleared
assert tasks_cancelled >= 1, "At least one task should have been cancelled"
assert len(session._guardrail_tasks) == 0