
Commit c7d50cb

Realtime: use SDK types for all messages (#1134)
To ensure we are emitting correct events, use the realtime types from the OpenAI SDK for all messages. Conversion tests are added as well. Stacked PRs: #1135, #1134 (this PR).
1 parent 8fdbe09 commit c7d50cb
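
For illustration only (not part of the diff): the core pattern this commit adopts. Client events that were previously hand-assembled as {"type": ..., "other_data": {...}} dicts and json.dumps-ed over the websocket are now built as typed SDK models and serialized with model_dump_json. A minimal sketch, assuming openai>=1.96.0:

from openai.types.beta.realtime.session_update_event import (
    Session as OpenAISessionObject,
    SessionUpdateEvent as OpenAISessionUpdateEvent,
)

# A typed event replaces a raw {"type": "session.update", "other_data": {...}} dict.
event = OpenAISessionUpdateEvent(
    type="session.update",
    session=OpenAISessionObject(tracing="auto"),  # "auto" is only an example value
)

# Mirrors what _send_raw_message now sends over the websocket.
payload = event.model_dump_json(exclude_none=True, exclude_unset=True)
print(payload)  # roughly: {"session": {"tracing": "auto"}, "type": "session.update"}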

File tree

7 files changed: +575, -140 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ requires-python = ">=3.9"
 license = "MIT"
 authors = [{ name = "OpenAI", email = "[email protected]" }]
 dependencies = [
-    "openai>=1.93.1, <2",
+    "openai>=1.96.0, <2",
     "pydantic>=2.10, <3",
     "griffe>=1.5.6, <2",
     "typing-extensions>=4.12.2, <5",

src/agents/realtime/items.py

Lines changed: 1 addition & 0 deletions
@@ -73,6 +73,7 @@ class AssistantMessageItem(BaseModel):
 class RealtimeToolCallItem(BaseModel):
     item_id: str
     previous_item_id: str | None = None
+    call_id: str | None
     type: Literal["function_call"] = "function_call"
     status: Literal["in_progress", "completed"]
     arguments: str

src/agents/realtime/openai_realtime.py

Lines changed: 165 additions & 97 deletions
@@ -10,14 +10,47 @@
 
 import pydantic
 import websockets
-from openai.types.beta.realtime.conversation_item import ConversationItem
+from openai.types.beta.realtime.conversation_item import (
+    ConversationItem,
+    ConversationItem as OpenAIConversationItem,
+)
+from openai.types.beta.realtime.conversation_item_content import (
+    ConversationItemContent as OpenAIConversationItemContent,
+)
+from openai.types.beta.realtime.conversation_item_create_event import (
+    ConversationItemCreateEvent as OpenAIConversationItemCreateEvent,
+)
+from openai.types.beta.realtime.conversation_item_retrieve_event import (
+    ConversationItemRetrieveEvent as OpenAIConversationItemRetrieveEvent,
+)
+from openai.types.beta.realtime.conversation_item_truncate_event import (
+    ConversationItemTruncateEvent as OpenAIConversationItemTruncateEvent,
+)
+from openai.types.beta.realtime.input_audio_buffer_append_event import (
+    InputAudioBufferAppendEvent as OpenAIInputAudioBufferAppendEvent,
+)
+from openai.types.beta.realtime.input_audio_buffer_commit_event import (
+    InputAudioBufferCommitEvent as OpenAIInputAudioBufferCommitEvent,
+)
+from openai.types.beta.realtime.realtime_client_event import (
+    RealtimeClientEvent as OpenAIRealtimeClientEvent,
+)
 from openai.types.beta.realtime.realtime_server_event import (
     RealtimeServerEvent as OpenAIRealtimeServerEvent,
 )
 from openai.types.beta.realtime.response_audio_delta_event import ResponseAudioDeltaEvent
+from openai.types.beta.realtime.response_cancel_event import (
+    ResponseCancelEvent as OpenAIResponseCancelEvent,
+)
+from openai.types.beta.realtime.response_create_event import (
+    ResponseCreateEvent as OpenAIResponseCreateEvent,
+)
 from openai.types.beta.realtime.session_update_event import (
     Session as OpenAISessionObject,
     SessionTool as OpenAISessionTool,
+    SessionTracing as OpenAISessionTracing,
+    SessionTracingTracingConfiguration as OpenAISessionTracingConfiguration,
+    SessionUpdateEvent as OpenAISessionUpdateEvent,
 )
 from pydantic import TypeAdapter
 from typing_extensions import assert_never
@@ -135,12 +168,11 @@ async def _send_tracing_config(
     ) -> None:
         """Update tracing configuration via session.update event."""
         if tracing_config is not None:
+            converted_tracing_config = _ConversionHelper.convert_tracing_config(tracing_config)
             await self._send_raw_message(
-                RealtimeModelSendRawMessage(
-                    message={
-                        "type": "session.update",
-                        "other_data": {"session": {"tracing": tracing_config}},
-                    }
+                OpenAISessionUpdateEvent(
+                    session=OpenAISessionObject(tracing=converted_tracing_config),
+                    type="session.update",
                 )
             )
 
@@ -199,7 +231,11 @@ async def _listen_for_messages(self):
     async def send_event(self, event: RealtimeModelSendEvent) -> None:
         """Send an event to the model."""
         if isinstance(event, RealtimeModelSendRawMessage):
-            await self._send_raw_message(event)
+            converted = _ConversionHelper.try_convert_raw_message(event)
+            if converted is not None:
+                await self._send_raw_message(converted)
+            else:
+                logger.error(f"Failed to convert raw message: {event}")
         elif isinstance(event, RealtimeModelSendUserInput):
             await self._send_user_input(event)
         elif isinstance(event, RealtimeModelSendAudio):
@@ -214,77 +250,33 @@ async def send_event(self, event: RealtimeModelSendEvent) -> None:
             assert_never(event)
             raise ValueError(f"Unknown event type: {type(event)}")
 
-    async def _send_raw_message(self, event: RealtimeModelSendRawMessage) -> None:
+    async def _send_raw_message(self, event: OpenAIRealtimeClientEvent) -> None:
         """Send a raw message to the model."""
         assert self._websocket is not None, "Not connected"
 
-        converted_event = {
-            "type": event.message["type"],
-        }
-
-        converted_event.update(event.message.get("other_data", {}))
-
-        await self._websocket.send(json.dumps(converted_event))
+        await self._websocket.send(event.model_dump_json(exclude_none=True, exclude_unset=True))
 
     async def _send_user_input(self, event: RealtimeModelSendUserInput) -> None:
-        message = (
-            event.user_input
-            if isinstance(event.user_input, dict)
-            else {
-                "type": "message",
-                "role": "user",
-                "content": [{"type": "input_text", "text": event.user_input}],
-            }
-        )
-        other_data = {
-            "item": message,
-        }
-
-        await self._send_raw_message(
-            RealtimeModelSendRawMessage(
-                message={"type": "conversation.item.create", "other_data": other_data}
-            )
-        )
-        await self._send_raw_message(
-            RealtimeModelSendRawMessage(message={"type": "response.create"})
-        )
+        converted = _ConversionHelper.convert_user_input_to_item_create(event)
+        await self._send_raw_message(converted)
+        await self._send_raw_message(OpenAIResponseCreateEvent(type="response.create"))
 
     async def _send_audio(self, event: RealtimeModelSendAudio) -> None:
-        base64_audio = base64.b64encode(event.audio).decode("utf-8")
-        await self._send_raw_message(
-            RealtimeModelSendRawMessage(
-                message={
-                    "type": "input_audio_buffer.append",
-                    "other_data": {
-                        "audio": base64_audio,
-                    },
-                }
-            )
-        )
+        converted = _ConversionHelper.convert_audio_to_input_audio_buffer_append(event)
+        await self._send_raw_message(converted)
         if event.commit:
             await self._send_raw_message(
-                RealtimeModelSendRawMessage(message={"type": "input_audio_buffer.commit"})
+                OpenAIInputAudioBufferCommitEvent(type="input_audio_buffer.commit")
             )
 
     async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None:
-        await self._send_raw_message(
-            RealtimeModelSendRawMessage(
-                message={
-                    "type": "conversation.item.create",
-                    "other_data": {
-                        "item": {
-                            "type": "function_call_output",
-                            "output": event.output,
-                            "call_id": event.tool_call.id,
-                        },
-                    },
-                }
-            )
-        )
+        converted = _ConversionHelper.convert_tool_output(event)
+        await self._send_raw_message(converted)
 
         tool_item = RealtimeToolCallItem(
             item_id=event.tool_call.id or "",
             previous_item_id=event.tool_call.previous_item_id,
+            call_id=event.tool_call.call_id,
            type="function_call",
             status="completed",
             arguments=event.tool_call.arguments,
@@ -294,9 +286,7 @@ async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None:
         await self._emit_event(RealtimeModelItemUpdatedEvent(item=tool_item))
 
         if event.start_response:
-            await self._send_raw_message(
-                RealtimeModelSendRawMessage(message={"type": "response.create"})
-            )
+            await self._send_raw_message(OpenAIResponseCreateEvent(type="response.create"))
 
     async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
         if not self._current_item_id or not self._audio_start_time:
@@ -307,18 +297,12 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
         elapsed_time_ms = (datetime.now() - self._audio_start_time).total_seconds() * 1000
         if elapsed_time_ms > 0 and elapsed_time_ms < self._audio_length_ms:
             await self._emit_event(RealtimeModelAudioInterruptedEvent())
-            await self._send_raw_message(
-                RealtimeModelSendRawMessage(
-                    message={
-                        "type": "conversation.item.truncate",
-                        "other_data": {
-                            "item_id": self._current_item_id,
-                            "content_index": self._current_audio_content_index,
-                            "audio_end_ms": elapsed_time_ms,
-                        },
-                    }
-                )
+            converted = _ConversionHelper.convert_interrupt(
+                self._current_item_id,
+                self._current_audio_content_index or 0,
+                int(elapsed_time_ms),
             )
+            await self._send_raw_message(converted)
 
         self._current_item_id = None
         self._audio_start_time = None
@@ -354,6 +338,7 @@ async def _handle_output_item(self, item: ConversationItem) -> None:
             tool_call = RealtimeToolCallItem(
                 item_id=item.id or "",
                 previous_item_id=None,
+                call_id=item.call_id,
                 type="function_call",
                 # We use the same item for tool call and output, so it will be completed by the
                 # output being added
@@ -365,7 +350,7 @@ async def _handle_output_item(self, item: ConversationItem) -> None:
             await self._emit_event(RealtimeModelItemUpdatedEvent(item=tool_call))
             await self._emit_event(
                 RealtimeModelToolCallEvent(
-                    call_id=item.id or "",
+                    call_id=item.call_id or "",
                     name=item.name or "",
                     arguments=item.arguments or "",
                     id=item.id or "",
@@ -404,9 +389,7 @@ async def close(self) -> None:
 
     async def _cancel_response(self) -> None:
         if self._ongoing_response:
-            await self._send_raw_message(
-                RealtimeModelSendRawMessage(message={"type": "response.cancel"})
-            )
+            await self._send_raw_message(OpenAIResponseCancelEvent(type="response.cancel"))
             self._ongoing_response = False
 
     async def _handle_ws_event(self, event: dict[str, Any]):
@@ -466,16 +449,13 @@ async def _handle_ws_event(self, event: dict[str, Any]):
             parsed.type == "conversation.item.input_audio_transcription.completed"
             or parsed.type == "conversation.item.truncated"
         ):
-            await self._send_raw_message(
-                RealtimeModelSendRawMessage(
-                    message={
-                        "type": "conversation.item.retrieve",
-                        "other_data": {
-                            "item_id": self._current_item_id,
-                        },
-                    }
+            if self._current_item_id:
+                await self._send_raw_message(
+                    OpenAIConversationItemRetrieveEvent(
+                        type="conversation.item.retrieve",
+                        item_id=self._current_item_id,
+                    )
                 )
-            )
             if parsed.type == "conversation.item.input_audio_transcription.completed":
                 await self._emit_event(
                     RealtimeModelInputAudioTranscriptionCompletedEvent(
@@ -504,14 +484,7 @@ async def _handle_ws_event(self, event: dict[str, Any]):
     async def _update_session_config(self, model_settings: RealtimeSessionModelSettings) -> None:
         session_config = self._get_session_config(model_settings)
         await self._send_raw_message(
-            RealtimeModelSendRawMessage(
-                message={
-                    "type": "session.update",
-                    "other_data": {
-                        "session": session_config.model_dump(exclude_unset=True, exclude_none=True)
-                    },
-                }
-            )
+            OpenAISessionUpdateEvent(session=session_config, type="session.update")
         )
 
     def _get_session_config(
@@ -582,3 +555,98 @@ def conversation_item_to_realtime_message_item(
                 "status": "in_progress",
             },
         )
+
+    @classmethod
+    def try_convert_raw_message(
+        cls, message: RealtimeModelSendRawMessage
+    ) -> OpenAIRealtimeClientEvent | None:
+        try:
+            data = {}
+            data["type"] = message.message["type"]
+            data.update(message.message.get("other_data", {}))
+            return TypeAdapter(OpenAIRealtimeClientEvent).validate_python(data)
+        except Exception:
+            return None
+
+    @classmethod
+    def convert_tracing_config(
+        cls, tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None
+    ) -> OpenAISessionTracing | None:
+        if tracing_config is None:
+            return None
+        elif tracing_config == "auto":
+            return "auto"
+        return OpenAISessionTracingConfiguration(
+            group_id=tracing_config.get("group_id"),
+            metadata=tracing_config.get("metadata"),
+            workflow_name=tracing_config.get("workflow_name"),
+        )
+
+    @classmethod
+    def convert_user_input_to_conversation_item(
+        cls, event: RealtimeModelSendUserInput
+    ) -> OpenAIConversationItem:
+        user_input = event.user_input
+
+        if isinstance(user_input, dict):
+            return OpenAIConversationItem(
+                type="message",
+                role="user",
+                content=[
+                    OpenAIConversationItemContent(
+                        type="input_text",
+                        text=item.get("text"),
+                    )
+                    for item in user_input.get("content", [])
+                ],
+            )
+        else:
+            return OpenAIConversationItem(
+                type="message",
+                role="user",
+                content=[OpenAIConversationItemContent(type="input_text", text=user_input)],
+            )
+
+    @classmethod
+    def convert_user_input_to_item_create(
+        cls, event: RealtimeModelSendUserInput
+    ) -> OpenAIRealtimeClientEvent:
+        return OpenAIConversationItemCreateEvent(
+            type="conversation.item.create",
+            item=cls.convert_user_input_to_conversation_item(event),
+        )
+
+    @classmethod
+    def convert_audio_to_input_audio_buffer_append(
+        cls, event: RealtimeModelSendAudio
+    ) -> OpenAIRealtimeClientEvent:
+        base64_audio = base64.b64encode(event.audio).decode("utf-8")
+        return OpenAIInputAudioBufferAppendEvent(
+            type="input_audio_buffer.append",
+            audio=base64_audio,
+        )
+
+    @classmethod
+    def convert_tool_output(cls, event: RealtimeModelSendToolOutput) -> OpenAIRealtimeClientEvent:
+        return OpenAIConversationItemCreateEvent(
+            type="conversation.item.create",
+            item=OpenAIConversationItem(
+                type="function_call_output",
+                output=event.output,
+                call_id=event.tool_call.call_id,
+            ),
+        )
+
+    @classmethod
+    def convert_interrupt(
+        cls,
+        current_item_id: str,
+        current_audio_content_index: int,
+        elapsed_time_ms: int,
+    ) -> OpenAIRealtimeClientEvent:
+        return OpenAIConversationItemTruncateEvent(
+            type="conversation.item.truncate",
+            item_id=current_item_id,
+            content_index=current_audio_content_index,
+            audio_end_ms=elapsed_time_ms,
+        )
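
For illustration only (not part of the diff): what _ConversionHelper.try_convert_raw_message does for the raw-message escape hatch. The {"type": ..., "other_data": {...}} payload of a RealtimeModelSendRawMessage is flattened into one dict and validated into a typed OpenAI client event via pydantic's TypeAdapter; anything that fails validation returns None, and send_event logs an error instead of sending it. A minimal standalone sketch of that validation step:

from openai.types.beta.realtime.realtime_client_event import (
    RealtimeClientEvent as OpenAIRealtimeClientEvent,
)
from pydantic import TypeAdapter

# Hypothetical raw payload a caller might pass via RealtimeModelSendRawMessage.
raw = {"type": "input_audio_buffer.commit"}

# The discriminated union picks the concrete event class from the "type" field.
converted = TypeAdapter(OpenAIRealtimeClientEvent).validate_python(raw)
print(type(converted).__name__)  # InputAudioBufferCommitEvent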
