 from pathlib import Path
 from dotenv import load_dotenv
 
+# Load environment variables BEFORE importing the agent
+load_dotenv()
+
+from google.genai import types
 from google.genai.types import (
     Part,
     Content,
     Blob,
 )
 
-from google.adk.runners import InMemoryRunner
+from google.adk.runners import Runner
 from google.adk.agents import LiveRequestQueue
-from google.adk.agents.run_config import RunConfig
-from google.genai import types
+from google.adk.agents.run_config import RunConfig, StreamingMode
+from google.adk.sessions.in_memory_session_service import InMemorySessionService
 
 from fastapi import FastAPI, WebSocket
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import FileResponse
+from fastapi.websockets import WebSocketDisconnect
 
 from google_search_agent.agent import root_agent
 
 # ADK Streaming
 #
 
-# Load Gemini API Key
-load_dotenv()
+# Application configuration
+APP_NAME = "adk-streaming-ws"
 
-APP_NAME = "ADK Streaming example"
+# Initialize session service
+session_service = InMemorySessionService()
 
+# APP_NAME and session_service are defined in the Initialization section above
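+# A single Runner instance is shared by all WebSocket connections; per-user state lives in sessions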
+runner = Runner(
+    app_name=APP_NAME,
+    agent=root_agent,
+    session_service=session_service,
+)
 
 async def start_agent_session(user_id, is_audio=False):
     """Starts an agent session"""
 
-    # Create a Runner
-    runner = InMemoryRunner(
+    # Get or create session (recommended pattern for production)
+    session_id = f"{APP_NAME}_{user_id}"
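+    # The deterministic session ID lets a reconnecting client pick up its existing session (while the server process lives)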
+    session = await runner.session_service.get_session(
         app_name=APP_NAME,
-        agent=root_agent,
+        user_id=user_id,
+        session_id=session_id,
     )
+    if not session:
+        session = await runner.session_service.create_session(
+            app_name=APP_NAME,
+            user_id=user_id,
+            session_id=session_id,
+        )
 
-    # Create a Session
-    session = await runner.session_service.create_session(
-        app_name=APP_NAME,
-        user_id=user_id,  # Replace with actual user ID
-    )
+    # Configure response format based on client preference
+    # IMPORTANT: You must choose exactly ONE modality per session
+    # Either ["TEXT"] for text responses OR ["AUDIO"] for voice responses
+    # You cannot use both modalities simultaneously in the same session
+
+    # Force AUDIO modality for native audio models regardless of client preference
+    model_name = root_agent.model if isinstance(root_agent.model, str) else root_agent.model.model
+    is_native_audio = "native-audio" in model_name.lower()
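+    # Native audio models generate speech directly, so they are always run with AUDIO responses here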
 
-    # Set response modality
-    modality = "AUDIO" if is_audio else "TEXT"
+    modality = "AUDIO" if (is_audio or is_native_audio) else "TEXT"
+
+    # Enable session resumption for improved reliability
+    # For audio mode, enable output transcription to get text for UI display
     run_config = RunConfig(
+        streaming_mode=StreamingMode.BIDI,
         response_modalities=[modality],
-        session_resumption=types.SessionResumptionConfig()
+        session_resumption=types.SessionResumptionConfig(),
+        output_audio_transcription=types.AudioTranscriptionConfig() if (is_audio or is_native_audio) else None,
     )
 
-    # Create a LiveRequestQueue for this session
+    # Create LiveRequestQueue in async context (recommended best practice)
+    # This ensures the queue uses the correct event loop
     live_request_queue = LiveRequestQueue()
 
-    # Start agent session
+    # Start streaming session - returns async iterator for agent responses
     live_events = runner.run_live(
-        session=session,
+        user_id=user_id,
+        session_id=session.id,
         live_request_queue=live_request_queue,
         run_config=run_config,
     )
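+    # live_events is an async iterator of agent Events; live_request_queue carries client input to the model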
@@ -86,69 +115,90 @@ async def start_agent_session(user_id, is_audio=False):
 
 async def agent_to_client_messaging(websocket, live_events):
     """Agent to client communication"""
-    async for event in live_events:
-
-        # If the turn complete or interrupted, send it
-        if event.turn_complete or event.interrupted:
-            message = {
-                "turn_complete": event.turn_complete,
-                "interrupted": event.interrupted,
-            }
-            await websocket.send_text(json.dumps(message))
-            print(f"[AGENT TO CLIENT]: {message}")
-            continue
-
-        # Read the Content and its first Part
-        part: Part = (
-            event.content and event.content.parts and event.content.parts[0]
-        )
-        if not part:
-            continue
-
-        # If it's audio, send Base64 encoded audio data
-        is_audio = part.inline_data and part.inline_data.mime_type.startswith("audio/pcm")
-        if is_audio:
-            audio_data = part.inline_data and part.inline_data.data
-            if audio_data:
+    try:
+        async for event in live_events:
+
+            # Handle output audio transcription for native audio models
+            # This provides a text representation of the audio output for UI display
+            if event.output_transcription and event.output_transcription.text:
+                transcript_text = event.output_transcription.text
                 message = {
-                    "mime_type": "audio/pcm",
-                    "data": base64.b64encode(audio_data).decode("ascii")
+                    "mime_type": "text/plain",
+                    "data": transcript_text,
+                    "is_transcript": True
                 }
                 await websocket.send_text(json.dumps(message))
-                print(f"[AGENT TO CLIENT]: audio/pcm: {len(audio_data)} bytes.")
-                continue
-
-        # If it's text and a partial text, send it
-        if part.text and event.partial:
-            message = {
-                "mime_type": "text/plain",
-                "data": part.text
-            }
-            await websocket.send_text(json.dumps(message))
-            print(f"[AGENT TO CLIENT]: text/plain: {message}")
+                print(f"[AGENT TO CLIENT]: audio transcript: {transcript_text}")
+                # Continue to process audio data if present
+                # Don't return here, as we may want to send both transcript and audio
+
+            # Read the Content and its first Part
+            part: Part = (
+                event.content and event.content.parts and event.content.parts[0]
+            )
+            if part:
+                # Audio data must be Base64-encoded for JSON transport
+                is_audio = part.inline_data and part.inline_data.mime_type.startswith("audio/pcm")
+                if is_audio:
+                    audio_data = part.inline_data and part.inline_data.data
+                    if audio_data:
+                        message = {
+                            "mime_type": "audio/pcm",
+                            "data": base64.b64encode(audio_data).decode("ascii")
+                        }
+                        await websocket.send_text(json.dumps(message))
+                        print(f"[AGENT TO CLIENT]: audio/pcm: {len(audio_data)} bytes.")
+
+                # If it's partial text, send it (for cascade audio models or text mode)
+                if part.text and event.partial:
+                    message = {
+                        "mime_type": "text/plain",
+                        "data": part.text
+                    }
+                    await websocket.send_text(json.dumps(message))
+                    print(f"[AGENT TO CLIENT]: text/plain: {message}")
+
+            # If the turn is complete or interrupted, send it
+            if event.turn_complete or event.interrupted:
+                message = {
+                    "turn_complete": event.turn_complete,
+                    "interrupted": event.interrupted,
+                }
+                await websocket.send_text(json.dumps(message))
+                print(f"[AGENT TO CLIENT]: {message}")
+    except WebSocketDisconnect:
+        print("Client disconnected from agent_to_client_messaging")
+    except Exception as e:
+        print(f"Error in agent_to_client_messaging: {e}")
 
 
 async def client_to_agent_messaging(websocket, live_request_queue):
     """Client to agent communication"""
-    while True:
-        # Decode JSON message
-        message_json = await websocket.receive_text()
-        message = json.loads(message_json)
-        mime_type = message["mime_type"]
-        data = message["data"]
-
-        # Send the message to the agent
-        if mime_type == "text/plain":
-            # Send a text message
-            content = Content(role="user", parts=[Part.from_text(text=data)])
-            live_request_queue.send_content(content=content)
-            print(f"[CLIENT TO AGENT]: {data}")
-        elif mime_type == "audio/pcm":
-            # Send an audio data
-            decoded_data = base64.b64decode(data)
-            live_request_queue.send_realtime(Blob(data=decoded_data, mime_type=mime_type))
-        else:
-            raise ValueError(f"Mime type not supported: {mime_type}")
+    try:
+        while True:
+            message_json = await websocket.receive_text()
+            message = json.loads(message_json)
+            mime_type = message["mime_type"]
+            data = message["data"]
+
+            if mime_type == "text/plain":
+                # send_content() sends text in "turn-by-turn mode":
+                # it signals a complete turn to the model, triggering an immediate response
+                content = Content(role="user", parts=[Part.from_text(text=data)])
+                live_request_queue.send_content(content=content)
+                print(f"[CLIENT TO AGENT]: {data}")
+            elif mime_type == "audio/pcm":
+                # send_realtime() sends audio in "realtime mode":
+                # data flows continuously without turn boundaries, enabling natural conversation
+                # Audio is Base64-encoded for JSON transport, so decode it before sending
+                decoded_data = base64.b64decode(data)
+                live_request_queue.send_realtime(Blob(data=decoded_data, mime_type=mime_type))
+            else:
+                raise ValueError(f"Mime type not supported: {mime_type}")
+    except WebSocketDisconnect:
+        print("Client disconnected from client_to_agent_messaging")
+    except Exception as e:
+        print(f"Error in client_to_agent_messaging: {e}")
 
 
 #
@@ -169,30 +219,39 @@ async def root():
 
 @app.websocket("/ws/{user_id}")
 async def websocket_endpoint(websocket: WebSocket, user_id: int, is_audio: str):
-    """Client websocket endpoint"""
+    """Client websocket endpoint.
+
+    The LiveRequestQueue is created (inside start_agent_session) from this
+    async context, which is the recommended best practice from the ADK
+    documentation: it ensures the queue is bound to the correct event loop.
+    """
 
-    # Wait for client connection
     await websocket.accept()
     print(f"Client #{user_id} connected, audio mode: {is_audio}")
 
-    # Start agent session
     user_id_str = str(user_id)
     live_events, live_request_queue = await start_agent_session(user_id_str, is_audio == "true")
 
-    # Start tasks
+    # Run bidirectional messaging concurrently
     agent_to_client_task = asyncio.create_task(
         agent_to_client_messaging(websocket, live_events)
     )
     client_to_agent_task = asyncio.create_task(
         client_to_agent_messaging(websocket, live_request_queue)
     )
 
-    # Wait until the websocket is disconnected or an error occurs
-    tasks = [agent_to_client_task, client_to_agent_task]
-    await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
-
-    # Close LiveRequestQueue
-    live_request_queue.close()
-
-    # Disconnected
-    print(f"Client #{user_id} disconnected")
+    try:
+        # Wait for either task to complete (connection close or error)
+        tasks = [agent_to_client_task, client_to_agent_task]
+        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
+
+        # Check for errors in completed tasks
+        for task in done:
+            if task.exception() is not None:
+                print(f"Task error for client #{user_id}: {task.exception()}")
+                import traceback
+                traceback.print_exception(type(task.exception()), task.exception(), task.exception().__traceback__)
+    finally:
+        # Clean up resources (always runs, even if asyncio.wait fails)
+        live_request_queue.close()
+        print(f"Client #{user_id} disconnected")