Skip to content

Commit caaefe8

Browse files
authored
Merge branch 'main' into fix-voice-stt-task-cleanup
2 parents 59ec0eb + 8c4d4d0 commit caaefe8

File tree

6 files changed

+124
-16
lines changed

6 files changed

+124
-16
lines changed

examples/realtime/app/agent.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
13
from agents import function_tool
24
from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX
35
from agents.realtime import RealtimeAgent, realtime_handoff
@@ -13,20 +15,26 @@
1315
name_override="faq_lookup_tool", description_override="Lookup frequently asked questions."
1416
)
1517
async def faq_lookup_tool(question: str) -> str:
16-
if "bag" in question or "baggage" in question:
18+
print("faq_lookup_tool called with question:", question)
19+
20+
# Simulate a slow API call
21+
await asyncio.sleep(3)
22+
23+
q = question.lower()
24+
if "wifi" in q or "wi-fi" in q:
25+
return "We have free wifi on the plane, join Airline-Wifi"
26+
elif "bag" in q or "baggage" in q:
1727
return (
1828
"You are allowed to bring one bag on the plane. "
1929
"It must be under 50 pounds and 22 inches x 14 inches x 9 inches."
2030
)
21-
elif "seats" in question or "plane" in question:
31+
elif "seats" in q or "plane" in q:
2232
return (
2333
"There are 120 seats on the plane. "
2434
"There are 22 business class seats and 98 economy seats. "
2535
"Exit rows are rows 4 and 16. "
2636
"Rows 5-8 are Economy Plus, with extra legroom. "
2737
)
28-
elif "wifi" in question:
29-
return "We have free wifi on the plane, join Airline-Wifi"
3038
return "I'm sorry, I don't know the answer to that question."
3139

3240

examples/realtime/app/server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ async def connect(self, websocket: WebSocket, session_id: str):
4747

4848
agent = get_starting_agent()
4949
runner = RealtimeRunner(agent)
50+
# If you want to customize the runner behavior, you can pass options:
51+
# runner_config = RealtimeRunConfig(async_tool_calls=False)
52+
# runner = RealtimeRunner(agent, config=runner_config)
5053
model_config: RealtimeModelConfig = {
5154
"initial_model_settings": {
5255
"turn_detection": {

src/agents/models/chatcmpl_stream_handler.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,12 @@ async def handle_stream(
150150
)
151151

152152
if reasoning_content and state.reasoning_content_index_and_output:
153+
# Ensure summary list has at least one element
154+
if not state.reasoning_content_index_and_output[1].summary:
155+
state.reasoning_content_index_and_output[1].summary = [
156+
Summary(text="", type="summary_text")
157+
]
158+
153159
yield ResponseReasoningSummaryTextDeltaEvent(
154160
delta=reasoning_content,
155161
item_id=FAKE_RESPONSES_ID,
@@ -201,7 +207,7 @@ async def handle_stream(
201207
)
202208

203209
# Create a new summary with updated text
204-
if state.reasoning_content_index_and_output[1].content is None:
210+
if not state.reasoning_content_index_and_output[1].content:
205211
state.reasoning_content_index_and_output[1].content = [
206212
Content(text="", type="reasoning_text")
207213
]

src/agents/realtime/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ class RealtimeRunConfig(TypedDict):
184184
tracing_disabled: NotRequired[bool]
185185
"""Whether tracing is disabled for this run."""
186186

187+
async_tool_calls: NotRequired[bool]
188+
"""Whether function tool calls should run asynchronously. Defaults to True."""
189+
187190
# TODO (rm) Add history audio storage config
188191

189192

src/agents/realtime/session.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(
112112
}
113113
self._event_queue: asyncio.Queue[RealtimeSessionEvent] = asyncio.Queue()
114114
self._closed = False
115-
self._stored_exception: Exception | None = None
115+
self._stored_exception: BaseException | None = None
116116

117117
# Guardrails state tracking
118118
self._interrupted_response_ids: set[str] = set()
@@ -123,6 +123,8 @@ def __init__(
123123
)
124124

125125
self._guardrail_tasks: set[asyncio.Task[Any]] = set()
126+
self._tool_call_tasks: set[asyncio.Task[Any]] = set()
127+
self._async_tool_calls: bool = bool(self._run_config.get("async_tool_calls", True))
126128

127129
@property
128130
def model(self) -> RealtimeModel:
@@ -216,7 +218,11 @@ async def on_event(self, event: RealtimeModelEvent) -> None:
216218
if event.type == "error":
217219
await self._put_event(RealtimeError(info=self._event_info, error=event.error))
218220
elif event.type == "function_call":
219-
await self._handle_tool_call(event)
221+
agent_snapshot = self._current_agent
222+
if self._async_tool_calls:
223+
self._enqueue_tool_call_task(event, agent_snapshot)
224+
else:
225+
await self._handle_tool_call(event, agent_snapshot=agent_snapshot)
220226
elif event.type == "audio":
221227
await self._put_event(
222228
RealtimeAudio(
@@ -384,11 +390,17 @@ async def _put_event(self, event: RealtimeSessionEvent) -> None:
384390
"""Put an event into the queue."""
385391
await self._event_queue.put(event)
386392

387-
async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
393+
async def _handle_tool_call(
394+
self,
395+
event: RealtimeModelToolCallEvent,
396+
*,
397+
agent_snapshot: RealtimeAgent | None = None,
398+
) -> None:
388399
"""Handle a tool call event."""
400+
agent = agent_snapshot or self._current_agent
389401
tools, handoffs = await asyncio.gather(
390-
self._current_agent.get_all_tools(self._context_wrapper),
391-
self._get_handoffs(self._current_agent, self._context_wrapper),
402+
agent.get_all_tools(self._context_wrapper),
403+
self._get_handoffs(agent, self._context_wrapper),
392404
)
393405
function_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
394406
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
@@ -398,7 +410,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
398410
RealtimeToolStart(
399411
info=self._event_info,
400412
tool=function_map[event.name],
401-
agent=self._current_agent,
413+
agent=agent,
402414
)
403415
)
404416

@@ -423,7 +435,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
423435
info=self._event_info,
424436
tool=func_tool,
425437
output=result,
426-
agent=self._current_agent,
438+
agent=agent,
427439
)
428440
)
429441
elif event.name in handoff_map:
@@ -444,7 +456,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
444456
)
445457

446458
# Store previous agent for event
447-
previous_agent = self._current_agent
459+
previous_agent = agent
448460

449461
# Update current agent
450462
self._current_agent = result
@@ -752,10 +764,49 @@ def _cleanup_guardrail_tasks(self) -> None:
752764
task.cancel()
753765
self._guardrail_tasks.clear()
754766

767+
def _enqueue_tool_call_task(
768+
self, event: RealtimeModelToolCallEvent, agent_snapshot: RealtimeAgent
769+
) -> None:
770+
"""Run tool calls in the background to avoid blocking realtime transport."""
771+
task = asyncio.create_task(self._handle_tool_call(event, agent_snapshot=agent_snapshot))
772+
self._tool_call_tasks.add(task)
773+
task.add_done_callback(self._on_tool_call_task_done)
774+
775+
def _on_tool_call_task_done(self, task: asyncio.Task[Any]) -> None:
776+
self._tool_call_tasks.discard(task)
777+
778+
if task.cancelled():
779+
return
780+
781+
exception = task.exception()
782+
if exception is None:
783+
return
784+
785+
logger.exception("Realtime tool call task failed", exc_info=exception)
786+
787+
if self._stored_exception is None:
788+
self._stored_exception = exception
789+
790+
asyncio.create_task(
791+
self._put_event(
792+
RealtimeError(
793+
info=self._event_info,
794+
error={"message": f"Tool call task failed: {exception}"},
795+
)
796+
)
797+
)
798+
799+
def _cleanup_tool_call_tasks(self) -> None:
800+
for task in self._tool_call_tasks:
801+
if not task.done():
802+
task.cancel()
803+
self._tool_call_tasks.clear()
804+
755805
async def _cleanup(self) -> None:
756806
"""Clean up all resources and mark session as closed."""
757807
# Cancel and cleanup guardrail tasks
758808
self._cleanup_guardrail_tasks()
809+
self._cleanup_tool_call_tasks()
759810

760811
# Remove ourselves as a listener
761812
self._model.remove_listener(self)

tests/realtime/test_session.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -561,8 +561,13 @@ async def test_ignored_events_only_generate_raw_events(self, mock_model, mock_ag
561561

562562
@pytest.mark.asyncio
563563
async def test_function_call_event_triggers_tool_handling(self, mock_model, mock_agent):
564-
"""Test that function_call events trigger tool call handling"""
565-
session = RealtimeSession(mock_model, mock_agent, None)
564+
"""Test that function_call events trigger tool call handling synchronously when disabled"""
565+
session = RealtimeSession(
566+
mock_model,
567+
mock_agent,
568+
None,
569+
run_config={"async_tool_calls": False},
570+
)
566571

567572
# Create function call event
568573
function_call_event = RealtimeModelToolCallEvent(
@@ -578,14 +583,46 @@ async def test_function_call_event_triggers_tool_handling(self, mock_model, mock
578583
await session.on_event(function_call_event)
579584

580585
# Should have called the tool handler
581-
handle_tool_call_mock.assert_called_once_with(function_call_event)
586+
handle_tool_call_mock.assert_called_once_with(
587+
function_call_event, agent_snapshot=mock_agent
588+
)
582589

583590
# Should still have raw event
584591
assert session._event_queue.qsize() == 1
585592
raw_event = await session._event_queue.get()
586593
assert isinstance(raw_event, RealtimeRawModelEvent)
587594
assert raw_event.data == function_call_event
588595

596+
@pytest.mark.asyncio
597+
async def test_function_call_event_runs_async_by_default(self, mock_model, mock_agent):
598+
"""Function call handling should be scheduled asynchronously by default"""
599+
session = RealtimeSession(mock_model, mock_agent, None)
600+
601+
function_call_event = RealtimeModelToolCallEvent(
602+
name="test_function",
603+
call_id="call_async",
604+
arguments='{"param": "value"}',
605+
)
606+
607+
with pytest.MonkeyPatch().context() as m:
608+
handle_tool_call_mock = AsyncMock()
609+
m.setattr(session, "_handle_tool_call", handle_tool_call_mock)
610+
611+
await session.on_event(function_call_event)
612+
613+
# Let the background task run
614+
await asyncio.sleep(0)
615+
616+
handle_tool_call_mock.assert_awaited_once_with(
617+
function_call_event, agent_snapshot=mock_agent
618+
)
619+
620+
# Raw event still enqueued
621+
assert session._event_queue.qsize() == 1
622+
raw_event = await session._event_queue.get()
623+
assert isinstance(raw_event, RealtimeRawModelEvent)
624+
assert raw_event.data == function_call_event
625+
589626

590627
class TestHistoryManagement:
591628
"""Test suite for history management and audio transcription in

0 commit comments

Comments
 (0)