@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
 #                2023 Nvidia (authors: Yuekai Zhang)
 #                2023 Recurrent.ai (authors: Songtao Shi)
@@ -46,7 +45,7 @@
 import json
 import queue  # Added
 import uuid  # Added
-import functools # Added
+import functools  # Added

 import os
 import time
@@ -56,9 +55,9 @@
 import numpy as np
 import soundfile as sf
 import tritonclient
-import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
-import tritonclient.grpc as grpcclient_sync # Added sync client import
-from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
+import tritonclient.grpc.aio as grpcclient_aio  # Renamed original import
+import tritonclient.grpc as grpcclient_sync  # Added sync client import
+from tritonclient.utils import np_to_triton_dtype, InferenceServerException  # Added InferenceServerException


 # --- Added UserData and callback ---
@@ -76,9 +75,10 @@ def get_first_chunk_latency(self):
             return self._first_chunk_time - self._start_time
         return None

+
 def callback(user_data, result, error):
     if user_data._first_chunk_time is None and not error:
-        user_data._first_chunk_time = time.time() # Record time of first successful chunk
+        user_data._first_chunk_time = time.time()  # Record time of first successful chunk
     if error:
         user_data._completed_requests.put(error)
     else:
@@ -206,8 +206,11 @@ def get_args():
206206 "--model-name" ,
207207 type = str ,
208208 default = "f5_tts" ,
209- choices = ["f5_tts" , "spark_tts" , "cosyvoice2" ],
210- help = "triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline" ,
209+ choices = [
210+ "f5_tts" ,
211+ "spark_tts" ,
212+ "cosyvoice2" ],
213+ help = "triton model_repo module name to request" ,
211214 )
212215
213216 parser .add_argument (
@@ -273,13 +276,14 @@ def load_audio(wav_path, target_sample_rate=16000):
         waveform = resample(waveform, num_samples)
     return waveform, target_sample_rate

+
 def prepare_request_input_output(
-    protocol_client, # Can be grpcclient_aio or grpcclient_sync
+    protocol_client,  # Can be grpcclient_aio or grpcclient_sync
     waveform,
     reference_text,
     target_text,
     sample_rate=16000,
-    padding_duration: int = None # Optional padding for offline mode
+    padding_duration: int = None  # Optional padding for offline mode
 ):
     """Prepares inputs for Triton inference (offline or streaming)."""
     assert len(waveform.shape) == 1, "waveform should be 1D"
@@ -291,9 +295,9 @@ def prepare_request_input_output(
         # Estimate target duration based on text length ratio (crude estimation)
         # Avoid division by zero if reference_text is empty
         if reference_text:
-            estimated_target_duration = duration / len(reference_text) * len(target_text)
+            estimated_target_duration = duration / len(reference_text) * len(target_text)
         else:
-            estimated_target_duration = duration # Assume target duration similar to reference if no text
+            estimated_target_duration = duration  # Assume target duration similar to reference if no text

         # Calculate required samples based on estimated total duration
         required_total_samples = padding_duration * sample_rate * (
@@ -329,6 +333,7 @@ def prepare_request_input_output(

     return inputs, outputs

+
 def run_sync_streaming_inference(
     sync_triton_client: tritonclient.grpc.InferenceServerClient,
     model_name: str,
@@ -342,7 +347,7 @@ def run_sync_streaming_inference(
342347):
343348 """Helper function to run the blocking sync streaming call."""
344349 start_time_total = time .time ()
345- user_data .record_start_time () # Record start time for first chunk latency calculation
350+ user_data .record_start_time () # Record start time for first chunk latency calculation
346351
347352 # Establish stream
348353 sync_triton_client .start_stream (callback = functools .partial (callback , user_data ))
@@ -360,27 +365,27 @@ def run_sync_streaming_inference(
     audios = []
     while True:
         try:
-            result = user_data._completed_requests.get() # Add timeout
+            result = user_data._completed_requests.get()  # Add timeout
             if isinstance(result, InferenceServerException):
                 print(f"Received InferenceServerException: {result}")
                 sync_triton_client.stop_stream()
-                return None, None, None # Indicate error
+                return None, None, None  # Indicate error
             # Get response metadata
             response = result.get_response()
             final = response.parameters["triton_final_response"].bool_param
             if final is True:
                 break

             audio_chunk = result.as_numpy("waveform").reshape(-1)
-            if audio_chunk.size > 0: # Only append non-empty chunks
-                audios.append(audio_chunk)
+            if audio_chunk.size > 0:  # Only append non-empty chunks
+                audios.append(audio_chunk)
             else:
                 print("Warning: received empty audio chunk.")

         except queue.Empty:
             print(f"Timeout waiting for response for request id {request_id}")
             sync_triton_client.stop_stream()
-            return None, None, None # Indicate error
+            return None, None, None  # Indicate error

     sync_triton_client.stop_stream()
     end_time_total = time.time()
@@ -398,19 +403,19 @@ def run_sync_streaming_inference(
     # Simplified reconstruction based on client_grpc_streaming.py
     if not audios:
         print("Warning: No audio chunks received.")
-        reconstructed_audio = np.array([], dtype=np.float32) # Empty array
+        reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
     elif len(audios) == 1:
         reconstructed_audio = audios[0]
     else:
-        reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
+        reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
         for i in range(1, len(audios)):
-            # Cross-fade section
-            cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
-                                   audios[i - 1][-cross_fade_samples:] * fade_out)
-            # Middle section of the current chunk
-            middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
-            # Concatenate
-            reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
+            # Cross-fade section
+            cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
+                                   audios[i - 1][-cross_fade_samples:] * fade_out)
+            # Middle section of the current chunk
+            middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
+            # Concatenate
+            reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
         # Add the last part of the final chunk
         reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])

@@ -421,19 +426,19 @@ def run_sync_streaming_inference(
             sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
         else:
             print("Warning: No audio chunks received or reconstructed.")
-            actual_duration = 0 # Set duration to 0 if no audio
+            actual_duration = 0  # Set duration to 0 if no audio

     else:
-        print("Warning: No audio chunks received.")
-        actual_duration = 0
+        print("Warning: No audio chunks received.")
+        actual_duration = 0

     return total_request_latency, first_chunk_latency, actual_duration


 async def send_streaming(
     manifest_item_list: list,
     name: str,
-    server_url: str, # Changed from sync_triton_client
+    server_url: str,  # Changed from sync_triton_client
     protocol_client: types.ModuleType,
     log_interval: int,
     model_name: str,
@@ -445,11 +450,11 @@ async def send_streaming(
     total_duration = 0.0
     latency_data = []
     task_id = int(name[5:])
-    sync_triton_client = None # Initialize client variable
+    sync_triton_client = None  # Initialize client variable

-    try: # Wrap in try...finally to ensure client closing
+    try:  # Wrap in try...finally to ensure client closing
         print(f"{name}: Initializing sync client for streaming...")
-        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
+        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)  # Create client here

         print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
         for i, item in enumerate(manifest_item_list):
@@ -491,8 +496,7 @@ async def send_streaming(
                     latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
                     total_duration += actual_duration
                 else:
-                    print(f"{name}: Item {i} failed.")
-
+                    print(f"{name}: Item {i} failed.")

             except FileNotFoundError:
                 print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
@@ -501,19 +505,18 @@ async def send_streaming(
                 import traceback
                 traceback.print_exc()

-
-    finally: # Ensure client is closed
+    finally:  # Ensure client is closed
         if sync_triton_client:
             try:
                 print(f"{name}: Closing sync client...")
                 sync_triton_client.close()
             except Exception as e:
                 print(f"{name}: Error closing sync client: {e}")

-
     print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
     return total_duration, latency_data

+
 async def send(
     manifest_item_list: list,
     name: str,
@@ -605,6 +608,7 @@ def split_data(data, k):

     return result

+
 async def main():
     args = get_args()
     url = f"{args.server_addr}:{args.server_port}"
@@ -622,7 +626,7 @@ async def main():
         # Use the sync client for streaming tasks, handled via asyncio.to_thread
         # We will create one sync client instance PER TASK inside send_streaming.
         # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
-        protocol_client = grpcclient_sync # protocol client for input prep
+        protocol_client = grpcclient_sync  # protocol client for input prep
     else:
         raise ValueError(f"Invalid mode: {args.mode}")
     # --- End Client Initialization ---
@@ -682,11 +686,11 @@ async def main():
                 )
             )
         elif args.mode == "streaming":
-            task = asyncio.create_task(
+            task = asyncio.create_task(
                 send_streaming(
                     manifest_item_list[i],
                     name=f"task-{i}",
-                    server_url=url, # Pass URL instead of client
+                    server_url=url,  # Pass URL instead of client
                     protocol_client=protocol_client,
                     log_interval=args.log_interval,
                     model_name=args.model_name,
@@ -709,16 +713,15 @@ async def main():
     for ans in ans_list:
         if ans:
             total_duration += ans[0]
-            latency_data.extend(ans[1]) # Use extend for list of lists
+            latency_data.extend(ans[1])  # Use extend for list of lists
         else:
-            print("Warning: A task returned None, possibly due to an error.")
-
+            print("Warning: A task returned None, possibly due to an error.")

     if total_duration == 0:
         print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
         rtf = float('inf')
     else:
-        rtf = elapsed / total_duration
+        rtf = elapsed / total_duration

     s = f"Mode: {args.mode}\n"
     s += f"RTF: {rtf:.4f}\n"
@@ -759,7 +762,7 @@ async def main():
759762 s += f"total_request_latency_99_percentile_ms: { np .percentile (total_latency_list , 99 ) * 1000.0 :.2f} \n "
760763 s += f"average_total_request_latency_ms: { avg_total_latency_ms :.2f} \n "
761764 else :
762- s += "No total request latency data collected.\n "
765+ s += "No total request latency data collected.\n "
763766
764767 s += "\n --- First Chunk Latency ---\n "
765768 if first_chunk_latency_list :
@@ -772,7 +775,7 @@ async def main():
772775 s += f"first_chunk_latency_99_percentile_ms: { np .percentile (first_chunk_latency_list , 99 ) * 1000.0 :.2f} \n "
773776 s += f"average_first_chunk_latency_ms: { avg_first_chunk_latency_ms :.2f} \n "
774777 else :
775- s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n "
778+ s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n "
776779 else :
777780 s += "No latency data collected.\n "
778781 # --- End Statistics Reporting ---
@@ -785,7 +788,7 @@ async def main():
     elif args.reference_audio:
         name = Path(args.reference_audio).stem
     else:
-        name = "results" # Default name if no manifest/split/audio provided
+        name = "results"  # Default name if no manifest/split/audio provided
     with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
         f.write(s)

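For readers skimming the diff: the chunk stitching done in `run_sync_streaming_inference` is a plain overlap-add with a linear cross-fade between consecutive streamed chunks. Below is a minimal standalone sketch of the same pattern; the helper name and the `fade_in`/`fade_out` ramps are assumptions (the diff does not show how the ramps are built), so treat it as an illustration rather than the exact client code.

```python
import numpy as np


def crossfade_chunks(chunks, cross_fade_samples):
    """Stitch streamed audio chunks, blending each overlap with a linear cross-fade."""
    if not chunks:
        return np.array([], dtype=np.float32)
    if len(chunks) == 1:
        return chunks[0]
    # Assumed ramp shape; the actual fade_in/fade_out arrays are defined outside the shown hunks.
    fade_in = np.linspace(0.0, 1.0, cross_fade_samples, dtype=np.float32)
    fade_out = 1.0 - fade_in
    out = chunks[0][:-cross_fade_samples]  # first chunk minus its trailing overlap
    for prev, cur in zip(chunks[:-1], chunks[1:]):
        # Blend the end of the previous chunk with the start of the current one.
        overlap = prev[-cross_fade_samples:] * fade_out + cur[:cross_fade_samples] * fade_in
        out = np.concatenate([out, overlap, cur[cross_fade_samples:-cross_fade_samples]])
    # Append the tail of the final chunk, which has no successor to blend with.
    return np.concatenate([out, chunks[-1][-cross_fade_samples:]])
```

This assumes every chunk is longer than twice the overlap; shorter chunks would need extra handling, which the client sidesteps by warning on and skipping empty chunks.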