
Commit 07cbc51

fix lint

committed
1 parent 1b8d194 commit 07cbc51

8 files changed: +165 −157 lines changed


runtime/triton_trtllm/client_grpc.py

Lines changed: 52 additions & 49 deletions
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
 # 2023 Nvidia (authors: Yuekai Zhang)
 # 2023 Recurrent.ai (authors: Songtao Shi)
@@ -46,7 +45,7 @@
 import json
 import queue  # Added
 import uuid  # Added
-import functools # Added
+import functools  # Added
 
 import os
 import time
@@ -56,9 +55,9 @@
 import numpy as np
 import soundfile as sf
 import tritonclient
-import tritonclient.grpc.aio as grpcclient_aio # Renamed original import
-import tritonclient.grpc as grpcclient_sync # Added sync client import
-from tritonclient.utils import np_to_triton_dtype, InferenceServerException # Added InferenceServerException
+import tritonclient.grpc.aio as grpcclient_aio  # Renamed original import
+import tritonclient.grpc as grpcclient_sync  # Added sync client import
+from tritonclient.utils import np_to_triton_dtype, InferenceServerException  # Added InferenceServerException
 
 
 # --- Added UserData and callback ---
@@ -76,9 +75,10 @@ def get_first_chunk_latency(self):
             return self._first_chunk_time - self._start_time
         return None
 
+
 def callback(user_data, result, error):
     if user_data._first_chunk_time is None and not error:
-        user_data._first_chunk_time = time.time() # Record time of first successful chunk
+        user_data._first_chunk_time = time.time()  # Record time of first successful chunk
     if error:
         user_data._completed_requests.put(error)
     else:
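Note for readers of this hunk: the Triton stream callback is invoked as callback(result, error); later in this diff the file binds the per-request UserData object as the first argument with functools.partial. A minimal, standard-library-only sketch of that binding (names mirror the file; only the UserData fields touched by this diff are shown, which is an assumption about the full class):

# Sketch of the UserData/callback binding used by this file (assumption: only
# the fields touched in this diff are shown; the real class may have more helpers).
import functools
import queue
import time


class UserData:
    def __init__(self):
        self._completed_requests = queue.Queue()
        self._start_time = None
        self._first_chunk_time = None

    def record_start_time(self):
        self._start_time = time.time()


def callback(user_data, result, error):
    # Triton calls the registered callback as (result, error); user_data is pre-bound.
    if user_data._first_chunk_time is None and not error:
        user_data._first_chunk_time = time.time()  # time of first successful chunk
    user_data._completed_requests.put(error if error else result)


user_data = UserData()
user_data.record_start_time()
stream_callback = functools.partial(callback, user_data)  # what start_stream() receives
stream_callback("fake-chunk", None)  # stands in for a real InferResult
print(user_data._completed_requests.get())  # -> "fake-chunk"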
@@ -206,8 +206,11 @@ def get_args():
         "--model-name",
         type=str,
         default="f5_tts",
-        choices=["f5_tts", "spark_tts", "cosyvoice2"],
-        help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
+        choices=[
+            "f5_tts",
+            "spark_tts",
+            "cosyvoice2"],
+        help="triton model_repo module name to request",
     )
 
     parser.add_argument(
@@ -273,13 +276,14 @@ def load_audio(wav_path, target_sample_rate=16000):
     waveform = resample(waveform, num_samples)
     return waveform, target_sample_rate
 
+
 def prepare_request_input_output(
-    protocol_client, # Can be grpcclient_aio or grpcclient_sync
+    protocol_client,  # Can be grpcclient_aio or grpcclient_sync
     waveform,
     reference_text,
     target_text,
     sample_rate=16000,
-    padding_duration: int = None # Optional padding for offline mode
+    padding_duration: int = None  # Optional padding for offline mode
 ):
     """Prepares inputs for Triton inference (offline or streaming)."""
     assert len(waveform.shape) == 1, "waveform should be 1D"
@@ -291,9 +295,9 @@ def prepare_request_input_output(
     # Estimate target duration based on text length ratio (crude estimation)
     # Avoid division by zero if reference_text is empty
     if reference_text:
-        estimated_target_duration = duration / len(reference_text) * len(target_text)
+        estimated_target_duration = duration / len(reference_text) * len(target_text)
     else:
-        estimated_target_duration = duration # Assume target duration similar to reference if no text
+        estimated_target_duration = duration  # Assume target duration similar to reference if no text
 
     # Calculate required samples based on estimated total duration
     required_total_samples = padding_duration * sample_rate * (
@@ -329,6 +333,7 @@ def prepare_request_input_output(
 
     return inputs, outputs
 
+
 def run_sync_streaming_inference(
     sync_triton_client: tritonclient.grpc.InferenceServerClient,
     model_name: str,
@@ -342,7 +347,7 @@ def run_sync_streaming_inference(
 ):
     """Helper function to run the blocking sync streaming call."""
     start_time_total = time.time()
-    user_data.record_start_time() # Record start time for first chunk latency calculation
+    user_data.record_start_time()  # Record start time for first chunk latency calculation
 
     # Establish stream
     sync_triton_client.start_stream(callback=functools.partial(callback, user_data))
@@ -360,27 +365,27 @@ def run_sync_streaming_inference(
     audios = []
     while True:
         try:
-            result = user_data._completed_requests.get() # Add timeout
+            result = user_data._completed_requests.get()  # Add timeout
             if isinstance(result, InferenceServerException):
                 print(f"Received InferenceServerException: {result}")
                 sync_triton_client.stop_stream()
-                return None, None, None # Indicate error
+                return None, None, None  # Indicate error
             # Get response metadata
             response = result.get_response()
             final = response.parameters["triton_final_response"].bool_param
             if final is True:
                 break
 
             audio_chunk = result.as_numpy("waveform").reshape(-1)
-            if audio_chunk.size > 0: # Only append non-empty chunks
-                audios.append(audio_chunk)
+            if audio_chunk.size > 0:  # Only append non-empty chunks
+                audios.append(audio_chunk)
             else:
                 print("Warning: received empty audio chunk.")
 
         except queue.Empty:
             print(f"Timeout waiting for response for request id {request_id}")
             sync_triton_client.stop_stream()
-            return None, None, None # Indicate error
+            return None, None, None  # Indicate error
 
     sync_triton_client.stop_stream()
     end_time_total = time.time()
@@ -398,19 +403,19 @@ def run_sync_streaming_inference(
     # Simplified reconstruction based on client_grpc_streaming.py
     if not audios:
         print("Warning: No audio chunks received.")
-        reconstructed_audio = np.array([], dtype=np.float32) # Empty array
+        reconstructed_audio = np.array([], dtype=np.float32)  # Empty array
     elif len(audios) == 1:
         reconstructed_audio = audios[0]
     else:
-        reconstructed_audio = audios[0][:-cross_fade_samples] # Start with first chunk minus overlap
+        reconstructed_audio = audios[0][:-cross_fade_samples]  # Start with first chunk minus overlap
         for i in range(1, len(audios)):
-            # Cross-fade section
-            cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
-                audios[i - 1][-cross_fade_samples:] * fade_out)
-            # Middle section of the current chunk
-            middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
-            # Concatenate
-            reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
+            # Cross-fade section
+            cross_faded_overlap = (audios[i][:cross_fade_samples] * fade_in +
+                                   audios[i - 1][-cross_fade_samples:] * fade_out)
+            # Middle section of the current chunk
+            middle_part = audios[i][cross_fade_samples:-cross_fade_samples]
+            # Concatenate
+            reconstructed_audio = np.concatenate([reconstructed_audio, cross_faded_overlap, middle_part])
         # Add the last part of the final chunk
         reconstructed_audio = np.concatenate([reconstructed_audio, audios[-1][-cross_fade_samples:]])
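For context, the re-indented block above stitches streamed chunks together with a linear cross-fade over cross_fade_samples samples. A standalone sketch of the same overlap-add follows; the np.linspace ramps are an assumption, since fade_in/fade_out are defined outside this hunk:

# Standalone sketch of the chunk cross-fade above (assumed: linear ramps,
# every chunk longer than 2 * cross_fade_samples).
import numpy as np

cross_fade_samples = 4
fade_in = np.linspace(0.0, 1.0, cross_fade_samples)
fade_out = 1.0 - fade_in


def cross_fade(audios):
    if not audios:
        return np.array([], dtype=np.float32)
    if len(audios) == 1:
        return audios[0]
    out = audios[0][:-cross_fade_samples]  # first chunk minus its tail overlap
    for i in range(1, len(audios)):
        overlap = (audios[i][:cross_fade_samples] * fade_in
                   + audios[i - 1][-cross_fade_samples:] * fade_out)
        middle = audios[i][cross_fade_samples:-cross_fade_samples]
        out = np.concatenate([out, overlap, middle])
    return np.concatenate([out, audios[-1][-cross_fade_samples:]])


chunks = [np.ones(16, dtype=np.float32), np.ones(16, dtype=np.float32)]
print(cross_fade(chunks).shape)  # (28,) == 16 + 16 - cross_fade_samples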

@@ -421,19 +426,19 @@ def run_sync_streaming_inference(
             sf.write(audio_save_path, reconstructed_audio, save_sample_rate, "PCM_16")
         else:
             print("Warning: No audio chunks received or reconstructed.")
-            actual_duration = 0 # Set duration to 0 if no audio
+            actual_duration = 0  # Set duration to 0 if no audio
 
     else:
-        print("Warning: No audio chunks received.")
-        actual_duration = 0
+        print("Warning: No audio chunks received.")
+        actual_duration = 0
 
     return total_request_latency, first_chunk_latency, actual_duration
 
 
 async def send_streaming(
     manifest_item_list: list,
     name: str,
-    server_url: str, # Changed from sync_triton_client
+    server_url: str,  # Changed from sync_triton_client
     protocol_client: types.ModuleType,
     log_interval: int,
     model_name: str,
@@ -445,11 +450,11 @@ async def send_streaming(
     total_duration = 0.0
     latency_data = []
     task_id = int(name[5:])
-    sync_triton_client = None # Initialize client variable
+    sync_triton_client = None  # Initialize client variable
 
-    try: # Wrap in try...finally to ensure client closing
+    try:  # Wrap in try...finally to ensure client closing
         print(f"{name}: Initializing sync client for streaming...")
-        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False) # Create client here
+        sync_triton_client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)  # Create client here
 
         print(f"{name}: Starting streaming processing for {len(manifest_item_list)} items.")
         for i, item in enumerate(manifest_item_list):
@@ -491,8 +496,7 @@ async def send_streaming(
                     latency_data.append((total_request_latency, first_chunk_latency, actual_duration))
                     total_duration += actual_duration
                 else:
-                    print(f"{name}: Item {i} failed.")
-
+                    print(f"{name}: Item {i} failed.")
 
             except FileNotFoundError:
                 print(f"Error: Audio file not found for item {i}: {item['audio_filepath']}")
@@ -501,19 +505,18 @@ async def send_streaming(
                 import traceback
                 traceback.print_exc()
 
-
-    finally: # Ensure client is closed
+    finally:  # Ensure client is closed
         if sync_triton_client:
             try:
                 print(f"{name}: Closing sync client...")
                 sync_triton_client.close()
             except Exception as e:
                 print(f"{name}: Error closing sync client: {e}")
 
-
     print(f"{name}: Finished streaming processing. Total duration synthesized: {total_duration:.4f}s")
     return total_duration, latency_data
 
+
 async def send(
     manifest_item_list: list,
     name: str,
@@ -605,6 +608,7 @@ def split_data(data, k):
 
     return result
 
+
 async def main():
     args = get_args()
     url = f"{args.server_addr}:{args.server_port}"
@@ -622,7 +626,7 @@ async def main():
         # Use the sync client for streaming tasks, handled via asyncio.to_thread
         # We will create one sync client instance PER TASK inside send_streaming.
         # triton_client = grpcclient_sync.InferenceServerClient(url=url, verbose=False) # REMOVED: Client created per task now
-        protocol_client = grpcclient_sync # protocol client for input prep
+        protocol_client = grpcclient_sync  # protocol client for input prep
     else:
         raise ValueError(f"Invalid mode: {args.mode}")
     # --- End Client Initialization ---
@@ -682,11 +686,11 @@ async def main():
                 )
             )
         elif args.mode == "streaming":
-            task = asyncio.create_task(
+            task = asyncio.create_task(
                 send_streaming(
                     manifest_item_list[i],
                     name=f"task-{i}",
-                    server_url=url, # Pass URL instead of client
+                    server_url=url,  # Pass URL instead of client
                     protocol_client=protocol_client,
                     log_interval=args.log_interval,
                     model_name=args.model_name,
@@ -709,16 +713,15 @@ async def main():
     for ans in ans_list:
         if ans:
             total_duration += ans[0]
-            latency_data.extend(ans[1]) # Use extend for list of lists
+            latency_data.extend(ans[1])  # Use extend for list of lists
         else:
-            print("Warning: A task returned None, possibly due to an error.")
-
+            print("Warning: A task returned None, possibly due to an error.")
 
     if total_duration == 0:
         print("Total synthesized duration is zero. Cannot calculate RTF or latency percentiles.")
         rtf = float('inf')
     else:
-        rtf = elapsed / total_duration
+        rtf = elapsed / total_duration
 
     s = f"Mode: {args.mode}\n"
     s += f"RTF: {rtf:.4f}\n"
@@ -759,7 +762,7 @@ async def main():
             s += f"total_request_latency_99_percentile_ms: {np.percentile(total_latency_list, 99) * 1000.0:.2f}\n"
             s += f"average_total_request_latency_ms: {avg_total_latency_ms:.2f}\n"
         else:
-            s += "No total request latency data collected.\n"
+            s += "No total request latency data collected.\n"
 
         s += "\n--- First Chunk Latency ---\n"
         if first_chunk_latency_list:
@@ -772,7 +775,7 @@ async def main():
             s += f"first_chunk_latency_99_percentile_ms: {np.percentile(first_chunk_latency_list, 99) * 1000.0:.2f}\n"
             s += f"average_first_chunk_latency_ms: {avg_first_chunk_latency_ms:.2f}\n"
         else:
-            s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
+            s += "No first chunk latency data collected (check for errors or if all requests failed before first chunk).\n"
     else:
         s += "No latency data collected.\n"
     # --- End Statistics Reporting ---
@@ -785,7 +788,7 @@ async def main():
     elif args.reference_audio:
         name = Path(args.reference_audio).stem
     else:
-        name = "results" # Default name if no manifest/split/audio provided
+        name = "results"  # Default name if no manifest/split/audio provided
     with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f:
         f.write(s)
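Taken as a whole, the streaming path in this file now creates one synchronous gRPC client per task and drives it from a worker thread. A condensed sketch of that flow is below; it reuses UserData/callback from this file, and the async_stream_infer keyword arguments (notably enable_empty_final_response) are assumptions about the unchanged portion of the script and a recent tritonclient release:

# Condensed sketch of the per-task sync streaming flow (assumptions noted above;
# UserData and callback are the definitions shown earlier in this diff).
import functools
import uuid

import tritonclient.grpc as grpcclient_sync
from tritonclient.utils import InferenceServerException


def stream_one_item(server_url, model_name, inputs, outputs, user_data):
    client = grpcclient_sync.InferenceServerClient(url=server_url, verbose=False)
    try:
        user_data.record_start_time()
        client.start_stream(callback=functools.partial(callback, user_data))
        client.async_stream_infer(
            model_name,
            inputs,
            request_id=str(uuid.uuid4()),
            outputs=outputs,
            enable_empty_final_response=True,  # assumed: decoupled model flags its last response
        )
        chunks = []
        while True:
            result = user_data._completed_requests.get()
            if isinstance(result, InferenceServerException):
                break  # the real code logs and returns an error triple here
            if result.get_response().parameters["triton_final_response"].bool_param:
                break
            chunks.append(result.as_numpy("waveform").reshape(-1))
        return chunks
    finally:
        client.stop_stream()
        client.close()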

runtime/triton_trtllm/client_http.py

Lines changed: 13 additions & 9 deletions
@@ -29,6 +29,7 @@
 import numpy as np
 import argparse
 
+
 def get_args():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -67,9 +68,10 @@ def get_args():
         type=str,
         default="spark_tts",
         choices=[
-            "f5_tts", "spark_tts", "cosyvoice2"
-        ],
-        help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline",
+            "f5_tts",
+            "spark_tts",
+            "cosyvoice2"],
+        help="triton model_repo module name to request",
     )
 
     parser.add_argument(
@@ -80,6 +82,7 @@ def get_args():
     )
     return parser.parse_args()
 
+
 def prepare_request(
     waveform,
     reference_text,
@@ -97,19 +100,19 @@ def prepare_request(
                 1,
                 padding_duration
                 * sample_rate
-                * ((int(duration) // padding_duration) + 1),
+                * ((int(len(waveform) / sample_rate) // padding_duration) + 1),
             ),
             dtype=np.float32,
         )
 
         samples[0, : len(waveform)] = waveform
     else:
         samples = waveform
-
+
     samples = samples.reshape(1, -1).astype(np.float32)
 
     data = {
-        "inputs":[
+        "inputs": [
             {
                 "name": "reference_wav",
                 "shape": samples.shape,
@@ -139,16 +142,17 @@ def prepare_request(
 
     return data
 
+
 if __name__ == "__main__":
     args = get_args()
     server_url = args.server_url
     if not server_url.startswith(("http://", "https://")):
         server_url = f"http://{server_url}"
-
+
     url = f"{server_url}/v2/models/{args.model_name}/infer"
     waveform, sr = sf.read(args.reference_audio)
     assert sr == 16000, "sample rate hardcoded in server"
-
+
     samples = np.array(waveform, dtype=np.float32)
     data = prepare_request(samples, args.reference_text, args.target_text)
 
@@ -166,4 +170,4 @@ def prepare_request(
         sample_rate = 16000
     else:
         sample_rate = 24000
-    sf.write(args.output_audio, audio, sample_rate, "PCM_16")
+    sf.write(args.output_audio, audio, sample_rate, "PCM_16")
