diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py
index 755a0265..7a3548ad 100644
--- a/torch2trt/torch2trt.py
+++ b/torch2trt/torch2trt.py
@@ -3,6 +3,7 @@ import copy
 import numpy as np
 import io
+import os
 from collections import defaultdict
 import importlib
@@ -538,6 +539,7 @@ def torch2trt(module,
               onnx_opset=None,
               max_batch_size=None,
               avg_timing_iterations=None,
+              engine_file_path='./trt_engine/engine.plan',
               **kwargs):
 
     # capture arguments to provide to context
@@ -559,151 +561,176 @@ def torch2trt(module,
     input_flattener = Flattener.from_value(inputs)
     output_flattener = Flattener.from_value(outputs)
 
-    # infer default parameters from dataset
-
-    if min_shapes == None:
-        min_shapes_flat = [tuple(t) for t in dataset.min_shapes(flat=True)]
-    else:
-        min_shapes_flat = input_flattener.flatten(min_shapes)
-
-    if max_shapes == None:
-        max_shapes_flat = [tuple(t) for t in dataset.max_shapes(flat=True)]
-    else:
-        max_shapes_flat = input_flattener.flatten(max_shapes)
-
-    if opt_shapes == None:
-        opt_shapes_flat = [tuple(t) for t in dataset.median_numel_shapes(flat=True)]
-    else:
-        opt_shapes_flat = input_flattener.flatten(opt_shapes)
-
-    # handle legacy max_batch_size
-    if max_batch_size is not None:
-        min_shapes_flat = [(1,) + s[1:] for s in min_shapes_flat]
-        max_shapes_flat = [(max_batch_size,) + s[1:] for s in max_shapes_flat]
-
-    dynamic_axes_flat = infer_dynamic_axes(min_shapes_flat, max_shapes_flat)
-
-    if default_device_type == trt.DeviceType.DLA:
-        for value in dynamic_axes_flat:
-            if len(value) > 0:
-                raise ValueError('Dataset cannot have multiple shapes when using DLA')
-
-    logger = trt.Logger(log_level)
-    builder = trt.Builder(logger)
-    config = builder.create_builder_config()
-
     if input_names is None:
         input_names = default_input_names(input_flattener.size)
     if output_names is None:
         output_names = default_output_names(output_flattener.size)
 
-    if use_onnx:
-        import onnx_graphsurgeon as gs
-        import onnx
-
-        module_flat = Flatten(module, input_flattener, output_flattener)
-        inputs_flat = input_flattener.flatten(inputs)
-
-        f = io.BytesIO()
-        torch.onnx.export(
-            module_flat,
-            inputs_flat,
-            f,
-            input_names=input_names,
-            output_names=output_names,
-            dynamic_axes={
-                name: {int(axis): f'input_{index}_axis_{axis}' for axis in dynamic_axes_flat[index]}
-                for index, name in enumerate(input_names)
-            },
-            opset_version=onnx_opset
-        )
-        f.seek(0)
+    def build_engine():
+        # infer default parameters from dataset
+
+        if min_shapes == None:
+            min_shapes_flat = [tuple(t) for t in dataset.min_shapes(flat=True)]
+        else:
+            min_shapes_flat = input_flattener.flatten(min_shapes)
+
+        if max_shapes == None:
+            max_shapes_flat = [tuple(t) for t in dataset.max_shapes(flat=True)]
+        else:
+            max_shapes_flat = input_flattener.flatten(max_shapes)
 
-        onnx_graph = gs.import_onnx(onnx.load(f))
-        onnx_graph.fold_constants().cleanup()
+        if opt_shapes == None:
+            opt_shapes_flat = [tuple(t) for t in dataset.median_numel_shapes(flat=True)]
+        else:
+            opt_shapes_flat = input_flattener.flatten(opt_shapes)
 
+        # handle legacy max_batch_size
+        if max_batch_size is not None:
+            min_shapes_flat = [(1,) + s[1:] for s in min_shapes_flat]
+            max_shapes_flat = [(max_batch_size,) + s[1:] for s in max_shapes_flat]
 
-        f = io.BytesIO()
-        onnx.save(gs.export_onnx(onnx_graph), f)
-        f.seek(0)
+        dynamic_axes_flat = infer_dynamic_axes(min_shapes_flat, max_shapes_flat)
 
-        onnx_bytes = f.read()
-        network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
-        parser = trt.OnnxParser(network, logger)
-        parser.parse(onnx_bytes)
+        if default_device_type == trt.DeviceType.DLA:
+            for value in dynamic_axes_flat:
+                if len(value) > 0:
+                    raise ValueError('Dataset cannot have multiple shapes when using DLA')
 
-    else:
-        network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
-        with ConversionContext(network, torch2trt_kwargs=kwargs, builder_config=config, logger=logger) as ctx:
-
+        logger = trt.Logger(log_level)
+        builder = trt.Builder(logger)
+        config = builder.create_builder_config()
+
+        if use_onnx:
+            import onnx_graphsurgeon as gs
+            import onnx
+
+            module_flat = Flatten(module, input_flattener, output_flattener)
             inputs_flat = input_flattener.flatten(inputs)
-            ctx.add_inputs(inputs_flat, input_names, dynamic_axes=dynamic_axes_flat)
 
+            f = io.BytesIO()
+            torch.onnx.export(
+                module_flat,
+                inputs_flat,
+                f,
+                input_names=input_names,
+                output_names=output_names,
+                dynamic_axes={
+                    name: {int(axis): f'input_{index}_axis_{axis}' for axis in dynamic_axes_flat[index]}
+                    for index, name in enumerate(input_names)
+                },
+                opset_version=onnx_opset
+            )
+            f.seek(0)
 
-            outputs = module(*inputs)
+            onnx_graph = gs.import_onnx(onnx.load(f))
+            onnx_graph.fold_constants().cleanup()
 
-            outputs_flat = output_flattener.flatten(outputs)
-            ctx.mark_outputs(outputs_flat, output_names)
-
-    # set max workspace size
-    if trt_version() < "10.0":
-        config.max_workspace_size = max_workspace_size
-
+            f = io.BytesIO()
+            onnx.save(gs.export_onnx(onnx_graph), f)
+            f.seek(0)
 
-    # set number of avg timing itrs.
-    if avg_timing_iterations is not None:
-        config.avg_timing_iterations = avg_timing_iterations
+            onnx_bytes = f.read()
+            network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
+            parser = trt.OnnxParser(network, logger)
+            parser.parse(onnx_bytes)
 
-    if fp16_mode:
-        config.set_flag(trt.BuilderFlag.FP16)
+        else:
+            network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
+            with ConversionContext(network, torch2trt_kwargs=kwargs, builder_config=config, logger=logger) as ctx:
 
-    config.default_device_type = default_device_type
-    if gpu_fallback:
-        config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
-    config.DLA_core = dla_core
-
-    if strict_type_constraints:
-        config.set_flag(trt.BuilderFlag.STRICT_TYPES)
+                inputs_flat = input_flattener.flatten(inputs)
+
+                ctx.add_inputs(inputs_flat, input_names, dynamic_axes=dynamic_axes_flat)
+
+                outputs = module(*inputs)
+
+                outputs_flat = output_flattener.flatten(outputs)
+                ctx.mark_outputs(outputs_flat, output_names)
+
+        # set max workspace size
+        if trt_version() < "10.0":
+            config.max_workspace_size = max_workspace_size
+
+        # set number of avg timing itrs.
+        if avg_timing_iterations is not None:
+            config.avg_timing_iterations = avg_timing_iterations
+
+        if fp16_mode:
+            config.set_flag(trt.BuilderFlag.FP16)
 
-    if int8_mode:
+        config.default_device_type = default_device_type
+        if gpu_fallback:
+            config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
+        config.DLA_core = dla_core
 
-        # default to use input tensors for calibration
-        if int8_calib_dataset is None:
-            int8_calib_dataset = dataset
+        if strict_type_constraints:
+            config.set_flag(trt.BuilderFlag.STRICT_TYPES)
 
-        config.set_flag(trt.BuilderFlag.INT8)
+        if int8_mode:
 
-        #Making sure not to run calibration with QAT mode on
-        if not 'qat_mode' in kwargs:
-            calibrator = DatasetCalibrator(
-                int8_calib_dataset, algorithm=int8_calib_algorithm
+            # default to use input tensors for calibration
+            if int8_calib_dataset is None:
+                int8_calib_dataset = dataset
+
+            config.set_flag(trt.BuilderFlag.INT8)
+
+            #Making sure not to run calibration with QAT mode on
+            if not 'qat_mode' in kwargs:
+                calibrator = DatasetCalibrator(
+                    int8_calib_dataset, algorithm=int8_calib_algorithm
+                )
+                config.int8_calibrator = calibrator
+
+        # OPTIMIZATION PROFILE
+        profile = builder.create_optimization_profile()
+        for index, name in enumerate(input_names):
+            profile.set_shape(
+                name,
+                min_shapes_flat[index],
+                opt_shapes_flat[index],
+                max_shapes_flat[index]
             )
-            config.int8_calibrator = calibrator
-
-    # OPTIMIZATION PROFILE
-    profile = builder.create_optimization_profile()
-    for index, name in enumerate(input_names):
-        profile.set_shape(
-            name,
-            min_shapes_flat[index],
-            opt_shapes_flat[index],
-            max_shapes_flat[index]
-        )
-    config.add_optimization_profile(profile)
+        config.add_optimization_profile(profile)
 
-    if int8_mode:
-        config.set_calibration_profile(profile)
+        if int8_mode:
+            config.set_calibration_profile(profile)
 
-    # BUILD ENGINE
+        # BUILD ENGINE
 
-    if trt_version() < "10.0":
-        engine = builder.build_engine(network, config)
-    else:
-        engine = builder.build_serialized_network(network, config)
+        if trt_version() < "10.0":
+            engine = builder.build_engine(network, config)
+        else:
+            engine = builder.build_serialized_network(network, config)
+
+        # SAVE ENGINE
+        engine_dir = os.path.dirname(engine_file_path)
+        if engine_dir:
+            os.makedirs(engine_dir, exist_ok=True)
+        if trt_version() < "10.0":
+            with open(engine_file_path, "wb") as f:
+                f.write(engine.serialize())
+        else:
+            with open(engine_file_path, "wb") as f:
+                f.write(engine)
+        return engine, network
+
+    load_engine = False
+    if os.path.exists(engine_file_path):
+        # If a serialized engine exists, use it instead of building a new one.
+        print("Reading engine from file {}.".format(engine_file_path))
+        try:
+            runtime = trt.Runtime(trt.Logger(log_level))
+            with open(engine_file_path, "rb") as f:
+                engine = runtime.deserialize_cuda_engine(f.read())
+            if engine is not None:
+                load_engine = True
+        except Exception:
+            print("Failed to load engine from file {}; falling back to build_engine().".format(engine_file_path))
+
+    if not load_engine:
+        engine, network = build_engine()
 
     module_trt = TRTModule(engine, input_names, output_names, input_flattener=input_flattener, output_flattener=output_flattener)
 
-    if keep_network:
+    if not load_engine and keep_network:
         module_trt.network = network
 
     return module_trt
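
With this patch, `torch2trt()` caches the built engine at `engine_file_path`: the first call runs `build_engine()` and serializes the plan to disk, and later calls deserialize that file instead of rebuilding. A minimal usage sketch of the caching path; the resnet18 model, input shape, and plan filename below are stand-ins, and only the new `engine_file_path` parameter comes from this patch:

```python
import torch
import torchvision
from torch2trt import torch2trt

# Stand-in model and example input; any torch2trt-convertible module works.
model = torchvision.models.resnet18().cuda().eval()
x = torch.randn(1, 3, 224, 224).cuda()

# First call: no file exists at the given path, so build_engine() runs
# (potentially for minutes) and writes the serialized plan to disk.
model_trt = torch2trt(model, [x], engine_file_path='./trt_engine/resnet18.plan')

# Subsequent calls find the file and deserialize it via trt.Runtime,
# skipping the builder; build_engine() only reruns if deserialization fails.
model_trt = torch2trt(model, [x], engine_file_path='./trt_engine/resnet18.plan')
```

Note that the cache key is just the file path: the patch does not invalidate a cached plan when the module, input shapes, or precision flags change, so a stale `.plan` file must be deleted manually. On the cached path no network is built, which is why `keep_network` is now guarded by `not load_engine`.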