
Commit 8c69f8d

fix perf bug
1 parent 0853838 commit 8c69f8d

File tree

3 files changed: +91 -46 lines changed

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 3 additions & 1 deletion
@@ -36,6 +36,8 @@ def constant_fold(
     # The constants are created on CPU to save GPU memory for TensorRT compilation.
     # For TRT INetwork construction the constants are moved to CPU in get_attr call.
     for node, constant in cf.node_replacements.items():
+        if node.target == torch.ops.aten.embedding.default:
+            continue
         replace_node_with_constant(
             gm, node, torch.nn.Parameter(constant, requires_grad=False)
         )

@@ -103,7 +105,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self.quantization_ops: Set[torch._ops.OpOverload] = set()
         try:
             # modelopt import ensures torch.ops.tensorrt.quantize_op.default is registered
-            import modelopt.torch.quantization as mtq
+            import modelopt.torch.quantization as mtq  # noqa: F401

             assert torch.ops.tensorrt.quantize_op.default
             self.quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
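Note on the constant-folding change: nodes targeting aten.embedding.default are now left in the graph rather than replaced with a materialized constant. This is presumably the perf bug named in the commit title: folding an embedding whose inputs are all constant bakes the full lookup result into the module as a new weight, while TensorRT can execute the gather natively. A minimal sketch of how the guard could be generalized, mirroring the loop above (the skip-set name is hypothetical, not part of this commit):

    # Hypothetical generalization: a skip-set of ops that should stay in
    # the graph even when all of their inputs are constant.
    FOLD_SKIP_OPS = {torch.ops.aten.embedding.default}

    for node, constant in cf.node_replacements.items():
        if node.target in FOLD_SKIP_OPS:
            continue  # let TensorRT execute the op natively
        replace_node_with_constant(
            gm, node, torch.nn.Parameter(constant, requires_grad=False)
        )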

tools/perf/perf_run.py

Lines changed: 86 additions & 45 deletions
@@ -174,8 +174,7 @@ def run_ts_trt(model, input_tensors, params, precision, batch_size):
     compile_settings = {
         "inputs": input_tensors,
         "enabled_precisions": {precision_to_dtype(precision)},
-        "truncate_long_and_double": params.get("truncate", False),
-        "use_python_runtime": params.get("use_python_runtime", False),
+        "truncate_double": params.get("truncate", False),
     }

     if precision == "int8":
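Note: truncate_long_and_double has been renamed to truncate_double in recent Torch-TensorRT releases (64-bit integers are now handled natively, so only doubles need truncation; see the trt.int64 mapping added in utils.py below), and use_python_runtime is dropped from the TorchScript path's settings. A usage sketch mirroring the renamed kwarg; model and input_tensors are placeholders:

    import torch
    import torch_tensorrt

    # Sketch only: the kwarg rename is the point.
    compile_settings = {
        "inputs": input_tensors,
        "enabled_precisions": {torch.half},
        "truncate_double": True,  # formerly truncate_long_and_double
    }
    trt_model = torch_tensorrt.compile(model, ir="ts", **compile_settings)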
@@ -274,8 +273,7 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
         ir="dynamo",
         enabled_precisions={precision_to_dtype(precision)},
         min_block_size=params.get("min_block_size", 1),
-        debug=False,
-        truncate_long_and_double=params.get("truncate", False),
+        truncate_double=params.get("truncate", False),
         immutable_weights=params.get("immutable_weights", True),
         strip_engine_weights=params.get("strip_engine_weights", False),
         refit_identical_engine_weights=params.get(
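The debug=False kwarg is dropped from the dynamo path as well; the per-compile debug flag appears to be deprecated in recent releases in favor of a dedicated debugger context, and False was already the default, so removing it does not change behavior.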
@@ -284,6 +282,7 @@ def run_dynamo(model, input_tensors, params, precision, batch_size):
         cache_built_engines=params.get("cache_built_engines", False),
         reuse_cached_engines=params.get("reuse_cached_engines", False),
         use_python_runtime=params.get("use_python_runtime", False),
+        optimization_level=5,
     )
     end_compile = timeit.default_timer()
     compile_time_s = end_compile - start_compile
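Note: TensorRT's builder optimization level ranges from 0 to 5, with a default of 3; higher levels let the builder spend more time searching kernel tactics, trading longer build times for faster engines. Pinning optimization_level=5 here, and config.builder_optimization_level = 5 in the raw-TensorRT path below, keeps the two backends comparable and biases the benchmark toward best achievable runtime.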
@@ -437,57 +436,97 @@ def run_tensorrt(
     precision,
     batch_size=1,
 ):
-    # Export an ONNX model and convert to TRT
-    torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
     logger = trt.Logger(trt.Logger.WARNING)
-    builder = trt.Builder(logger)
-    network = builder.create_network(
-        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
-    )
-    parser = trt.OnnxParser(network, logger)
-    success = parser.parse_from_file("./tmp.onnx")
-    if not success:
-        raise ValueError("ONNX conversion failed")
-
-    config = builder.create_builder_config()
-    if precision == "fp16":
-        config.set_flag(trt.BuilderFlag.FP16)
-    start_compile = timeit.default_timer()
-    serialized_engine = builder.build_serialized_network(network, config)
-    end_compile = timeit.default_timer()
-    compile_time_s = end_compile - start_compile
+    compile_time_s = 0
+    if params["is_trt_engine"]:
+        serialized_engine = model
+    else:
+        # Export an ONNX model and convert to TRT
+        torch.onnx.export(model.eval().cuda(), tuple(input_tensors), "./tmp.onnx")
+        builder = trt.Builder(logger)
+        network = builder.create_network(
+            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+        )
+        parser = trt.OnnxParser(network, logger)
+        success = parser.parse_from_file("./tmp.onnx")
+        if not success:
+            raise ValueError("ONNX conversion failed")
+
+        config = builder.create_builder_config()
+        if precision == "fp16":
+            config.set_flag(trt.BuilderFlag.FP16)
+        config.builder_optimization_level = 5
+        start_compile = timeit.default_timer()
+        serialized_engine = builder.build_serialized_network(network, config)
+        end_compile = timeit.default_timer()
+        compile_time_s = end_compile - start_compile
     # Deserialize the TensorRT engine
     with trt.Runtime(logger) as runtime:
         engine = runtime.deserialize_cuda_engine(serialized_engine)

     print("Running TensorRT for precision: ", precision, " batch_size : ", batch_size)
     iters = params.get("iterations", 20)

-    # Compiling the bindings
-    bindings = engine.num_bindings * [None]
-    k = 0
-    for idx, _ in enumerate(bindings):
-        dtype = torch_dtype_from_trt(engine.get_binding_dtype(idx))
-        shape = tuple(engine.get_binding_shape(idx))
-        device = torch_device_from_trt(engine.get_location(idx))
-        if not engine.binding_is_input(idx):
-            # Output bindings
-            output = torch.empty(size=shape, dtype=dtype, device=device)
-            bindings[idx] = output.data_ptr()
-        else:
-            # Input bindings
-            bindings[idx] = input_tensors[k].data_ptr()
-            k += 1
+    # Get I/O tensor information using TensorRT 10 API
+    input_names = []
+    output_names = []
+    input_dtypes = []
+    output_dtypes = []
+    input_shapes = []
+    output_shapes = []
+
+    for i in range(engine.num_io_tensors):
+        tensor_name = engine.get_tensor_name(i)
+        tensor_mode = engine.get_tensor_mode(tensor_name)
+        tensor_dtype = engine.get_tensor_dtype(tensor_name)
+        tensor_shape = engine.get_tensor_shape(tensor_name)
+
+        if tensor_mode == trt.TensorIOMode.INPUT:
+            input_names.append(tensor_name)
+            input_dtypes.append(torch_dtype_from_trt(tensor_dtype))
+            input_shapes.append(tuple(tensor_shape))
+        else:  # trt.TensorIOMode.OUTPUT
+            output_names.append(tensor_name)
+            output_dtypes.append(torch_dtype_from_trt(tensor_dtype))
+            output_shapes.append(tuple(tensor_shape))
+
+    # Create output tensors
+    output_tensors = []
+    for i, (shape, dtype) in enumerate(zip(output_shapes, output_dtypes)):
+        output = torch.empty(size=shape, dtype=dtype, device="cuda")
+        output_tensors.append(output)

     timings = []
     with engine.create_execution_context() as context:
+        # Set input tensor addresses
+        for i, (input_name, input_tensor) in enumerate(zip(input_names, input_tensors)):
+            context.set_tensor_address(input_name, input_tensor.data_ptr())
+
+        # Set output tensor addresses
+        for output_name, output_tensor in zip(output_names, output_tensors):
+            context.set_tensor_address(output_name, output_tensor.data_ptr())
+
+        # Create a dedicated stream for TensorRT execution
+        dedicated_stream = torch.cuda.Stream()
+        current_stream = torch.cuda.current_stream()
+
+        # Warm up
         for i in range(WARMUP_ITER):
-            context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
+            # Wait for current stream to finish
+            dedicated_stream.wait_stream(current_stream)
+            context.execute_async_v3(dedicated_stream.cuda_stream)
+            # Wait for TensorRT stream to finish
+            current_stream.wait_stream(dedicated_stream)
         torch.cuda.synchronize()

+        # Performance measurement
         for i in range(iters):
             start_time = timeit.default_timer()
-            context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
+            # Wait for current stream to finish
+            dedicated_stream.wait_stream(current_stream)
+            context.execute_async_v3(dedicated_stream.cuda_stream)
+            # Wait for TensorRT stream to finish
+            current_stream.wait_stream(dedicated_stream)
             torch.cuda.synchronize()
             end_time = timeit.default_timer()
             meas_time = end_time - start_time
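Two things change in run_tensorrt's execution loop. First, it moves off the index-based bindings API, which TensorRT 10 removed, onto the name-based I/O tensor API; the rough mapping is:

- engine.num_bindings → engine.num_io_tensors
- engine.get_binding_dtype(idx) / engine.get_binding_shape(idx) → engine.get_tensor_dtype(name) / engine.get_tensor_shape(name)
- engine.binding_is_input(idx) → engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
- context.execute_async_v2(bindings, stream) → context.set_tensor_address(name, ptr) per tensor, then context.execute_async_v3(stream)

Second, execution moves to a dedicated CUDA stream. The wait_stream calls on either side of execute_async_v3 order the TensorRT launch against outstanding PyTorch work without a device-wide sync inside the loop body; the torch.cuda.synchronize() that follows still bounds each timed iteration.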
@@ -504,7 +543,6 @@ def run(
     params,
     precision,
     batch_size=1,
-    is_trt_engine=False,
     model_torch=None,
 ):
     for backend in backends:
@@ -551,7 +589,6 @@ def run(
             input_tensors,
             params,
             precision,
-            is_trt_engine,
             batch_size,
         )
         run_dynamo(model_torch, input_tensors, params, precision, batch_size)
@@ -569,7 +606,7 @@ def run(
             )
         elif backend == "tensorrt":
             run_tensorrt(
-                model_torch,
+                model,
                 input_tensors,
                 params,
                 precision,
@@ -702,8 +739,13 @@ def run(

     # Load TorchScript model, if provided
     if os.path.exists(model_name):
-        print("Loading user provided torchscript model: ", model_name)
-        model = torch.jit.load(model_name).cuda().eval()
+        if params["is_trt_engine"]:
+            with open(model_name, "rb") as f:
+                model = f.read()
+            print("Loading user provided trt engine: ", model_name)
+        else:
+            print("Loading user provided torchscript model: ", model_name)
+            model = torch.jit.load(model_name).cuda().eval()

     # Load PyTorch Model, if provided
     if len(model_name_torch) > 0 and os.path.exists(model_name_torch):
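Note: when is_trt_engine is set, model now holds the raw serialized engine bytes read here, which is why run_tensorrt above receives model rather than model_torch, and why the per-call is_trt_engine argument could be dropped in favor of params. A minimal sketch of the consuming path; the file name is illustrative:

    import tensorrt as trt

    # Read a prebuilt serialized engine and deserialize it directly,
    # bypassing ONNX export and the builder.
    with open("model.engine", "rb") as f:
        serialized_engine = f.read()
    logger = trt.Logger(trt.Logger.WARNING)
    with trt.Runtime(logger) as runtime:
        engine = runtime.deserialize_cuda_engine(serialized_engine)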
@@ -746,7 +788,6 @@ def run(
         params,
         precision,
         batch_size,
-        is_trt_engine,
         model_torch=model_torch,
     )

tools/perf/utils.py

Lines changed: 2 additions & 0 deletions
@@ -176,6 +176,8 @@ def torch_dtype_from_trt(dtype):
         return torch.bool
     elif dtype == trt.int32:
         return torch.int32
+    elif dtype == trt.int64:
+        return torch.int64
     elif dtype == trt.float16:
         return torch.float16
     elif dtype == trt.float32:
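Note: TensorRT 10 supports INT64 I/O tensors natively (ONNX-exported models commonly carry int64 index tensors), so torch_dtype_from_trt needs the trt.int64 → torch.int64 case; without it, allocating an output tensor for an int64 binding in run_tensorrt would fail.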
