@@ -174,8 +174,7 @@ def run_ts_trt(model, input_tensors, params, precision, batch_size):
     compile_settings = {
         "inputs": input_tensors,
         "enabled_precisions": {precision_to_dtype(precision)},
-        "truncate_long_and_double": params.get("truncate", False),
-        "use_python_runtime": params.get("use_python_runtime", False),
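+        # Truncate float64 ("double") values to float32; TensorRT has no 64-bit float support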
+ "truncate_double" : params .get ("truncate" , False ),
179
178
}
180
179
181
180
if precision == "int8" :
@@ -274,8 +273,7 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
         ir="dynamo",
         enabled_precisions={precision_to_dtype(precision)},
         min_block_size=params.get("min_block_size", 1),
-        debug=False,
-        truncate_long_and_double=params.get("truncate", False),
+        truncate_double=params.get("truncate", False),
         immutable_weights=params.get("immutable_weights", True),
         strip_engine_weights=params.get("strip_engine_weights", False),
         refit_identical_engine_weights=params.get(
@@ -284,6 +282,7 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
         cache_built_engines=params.get("cache_built_engines", False),
         reuse_cached_engines=params.get("reuse_cached_engines", False),
         use_python_runtime=params.get("use_python_runtime", False),
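+        # Build with maximum optimization effort (TensorRT builder levels range 0-5)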
+        optimization_level=5,
     )
     end_compile = timeit.default_timer()
     compile_time_s = end_compile - start_compile
@@ -437,57 +436,97 @@ def run_tensorrt(
     precision,
     batch_size=1,
 ):
-    # Export an ONNX model and convert to TRT
-    torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
     logger = trt.Logger(trt.Logger.WARNING)
-    builder = trt.Builder(logger)
-    network = builder.create_network(
-        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
-    )
-    parser = trt.OnnxParser(network, logger)
-    success = parser.parse_from_file("./tmp.onnx")
-    if not success:
-        raise ValueError("ONNX conversion failed")
-
-    config = builder.create_builder_config()
-    if precision == "fp16":
-        config.set_flag(trt.BuilderFlag.FP16)
-    start_compile = timeit.default_timer()
-    serialized_engine = builder.build_serialized_network(network, config)
-    end_compile = timeit.default_timer()
-    compile_time_s = end_compile - start_compile
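+    # When a prebuilt serialized engine is supplied there is nothing to build, so no compile time is recorded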
+    compile_time_s = 0
+    if params["is_trt_engine"]:
+        serialized_engine = model
+    else:
+        # Export an ONNX model and convert to TRT
+        torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
+        builder = trt.Builder(logger)
+        network = builder.create_network(
+            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+        )
+        parser = trt.OnnxParser(network, logger)
+        success = parser.parse_from_file("./tmp.onnx")
+        if not success:
+            raise ValueError("ONNX conversion failed")
+
+        config = builder.create_builder_config()
+        if precision == "fp16":
+            config.set_flag(trt.BuilderFlag.FP16)
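+        # Maximum builder effort (0-5): longer builds in exchange for a potentially faster engine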
+        config.builder_optimization_level = 5
+        start_compile = timeit.default_timer()
+        serialized_engine = builder.build_serialized_network(network, config)
+        end_compile = timeit.default_timer()
+        compile_time_s = end_compile - start_compile
     # Deserialize the TensorRT engine
     with trt.Runtime(logger) as runtime:
         engine = runtime.deserialize_cuda_engine(serialized_engine)

     print("Running TensorRT for precision: ", precision, " batch_size : ", batch_size)
     iters = params.get("iterations", 20)

-    # Compiling the bindings
-    bindings = engine.num_bindings * [None]
-    k = 0
-    for idx, _ in enumerate(bindings):
-        dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))
-        shape = tuple(engine.get_binding_shape(idx))
-        device = torch_device_from_trt(engine.get_location(idx))
-        if not engine.binding_is_input(idx):
-            # Output bindings
-            output = torch.empty(size=shape, dtype=dtype, device=device)
-            bindings[idx] = output.data_ptr()
-        else:
-            # Input bindings
-            bindings[idx] = input_tensors[k].data_ptr()
-            k += 1
+    # Get I/O tensor information using TensorRT 10 API
+    input_names = []
+    output_names = []
+    input_dtypes = []
+    output_dtypes = []
+    input_shapes = []
+    output_shapes = []
+
+    for i in range(engine.num_io_tensors):
+        tensor_name = engine.get_tensor_name(i)
+        tensor_mode = engine.get_tensor_mode(tensor_name)
+        tensor_dtype = engine.get_tensor_dtype(tensor_name)
+        tensor_shape = engine.get_tensor_shape(tensor_name)
+
+        if tensor_mode == trt.TensorIOMode.INPUT:
+            input_names.append(tensor_name)
+            input_dtypes.append(torch_dtype_from_trt(tensor_dtype))
+            input_shapes.append(tuple(tensor_shape))
+        else:  # trt.TensorIOMode.OUTPUT
+            output_names.append(tensor_name)
+            output_dtypes.append(torch_dtype_from_trt(tensor_dtype))
+            output_shapes.append(tuple(tensor_shape))
+
+    # Create output tensors
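+    # (Assumes static output shapes; dynamic dimensions would appear as -1 in get_tensor_shape)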
+    output_tensors = []
+    for i, (shape, dtype) in enumerate(zip(output_shapes, output_dtypes)):
+        output = torch.empty(size=shape, dtype=dtype, device="cuda")
+        output_tensors.append(output)

     timings = []
     with engine.create_execution_context() as context:
+        # Set input tensor addresses
+        for i, (input_name, input_tensor) in enumerate(zip(input_names, input_tensors)):
+            context.set_tensor_address(input_name, input_tensor.data_ptr())
+
+        # Set output tensor addresses
+        for output_name, output_tensor in zip(output_names, output_tensors):
+            context.set_tensor_address(output_name, output_tensor.data_ptr())
+
+        # Create a dedicated stream for TensorRT execution
+        dedicated_stream = torch.cuda.Stream()
+        current_stream = torch.cuda.current_stream()
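+        # execute_async_v3 reads the tensor addresses registered above instead of a bindings list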
+
+        # Warm up
         for i in range(WARMUP_ITER):
-            context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
+            # Wait for current stream to finish
+            dedicated_stream.wait_stream(current_stream)
+            context.execute_async_v3(dedicated_stream.cuda_stream)
+            # Wait for TensorRT stream to finish
+            current_stream.wait_stream(dedicated_stream)
         torch.cuda.synchronize()

+        # Performance measurement
         for i in range(iters):
             start_time = timeit.default_timer()
-            context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
+            # Wait for current stream to finish
+            dedicated_stream.wait_stream(current_stream)
+            context.execute_async_v3(dedicated_stream.cuda_stream)
+            # Wait for TensorRT stream to finish
+            current_stream.wait_stream(dedicated_stream)
             torch.cuda.synchronize()
             end_time = timeit.default_timer()
             meas_time = end_time - start_time
@@ -504,7 +543,6 @@ def run(
     params,
     precision,
     batch_size=1,
-    is_trt_engine=False,
     model_torch=None,
 ):
     for backend in backends:
@@ -551,7 +589,6 @@ def run(
                 input_tensors,
                 params,
                 precision,
-                is_trt_engine,
                 batch_size,
             )
             run_dynamo(model_torch, input_tensors, params, precision, batch_size)
@@ -569,7 +606,7 @@ def run(
             )
         elif backend == "tensorrt":
             run_tensorrt(
-                model_torch,
+                model,
                 input_tensors,
                 params,
                 precision,
@@ -702,8 +739,13 @@ def run(

     # Load TorchScript model, if provided
     if os.path.exists(model_name):
-        print("Loading user provided torchscript model: ", model_name)
-        model = torch.jit.load(model_name).cuda().eval()
+        if params["is_trt_engine"]:
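+            # A serialized TRT engine is opaque bytes, not a TorchScript module, so read it directly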
+            with open(model_name, "rb") as f:
+                model = f.read()
+            print("Loading user provided trt engine: ", model_name)
+        else:
+            print("Loading user provided torchscript model: ", model_name)
+            model = torch.jit.load(model_name).cuda().eval()

     # Load PyTorch Model, if provided
     if len(model_name_torch) > 0 and os.path.exists(model_name_torch):
@@ -746,7 +788,6 @@ def run(
         params,
         precision,
         batch_size,
-        is_trt_engine,
         model_torch=model_torch,
     )