Skip to content

Commit

Permalink
update to tensorrt8.2.3
Browse files Browse the repository at this point in the history
  • Loading branch information
CoinCheung committed Jul 4, 2022
1 parent 44ab09b commit 4dcf170
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 23 deletions.
3 changes: 2 additions & 1 deletion tensorrt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ set(CMAKE_CXX_FLAGS "-std=c++14 -O1")


link_directories(/usr/local/cuda/lib64)
# set(OpenCV_DIR "/opt/opencv/lib/cmake/opencv4")
# include_directories(/root/build/TensorRT-8.2.5.1/include)
# link_directories(/root/build/TensorRT-8.2.5.1/lib)


find_package(CUDA REQUIRED)
Expand Down
54 changes: 36 additions & 18 deletions tensorrt/trt_dep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,28 @@ TrtSharedEnginePtr parse_to_engine(string onnx_pth, bool use_fp16) {
unsigned int maxBatchSize{1};
int memory_limit = 1U << 30; // 1G

auto builder = TrtUniquePtr<IBuilder>(nvinfer1::createInferBuilder(gLogger));
auto builder = TrtUnqPtr<IBuilder>(nvinfer1::createInferBuilder(gLogger));
if (!builder) {
cout << "create builder failed\n";
std::abort();
}

const auto explicitBatch = 1U << static_cast<uint32_t>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
auto network = TrtUniquePtr<INetworkDefinition>(
auto network = TrtUnqPtr<INetworkDefinition>(
builder->createNetworkV2(explicitBatch));
if (!network) {
cout << "create network failed\n";
std::abort();
}

auto config = TrtUniquePtr<IBuilderConfig>(builder->createBuilderConfig());
auto config = TrtUnqPtr<IBuilderConfig>(builder->createBuilderConfig());
if (!config) {
cout << "create builder config failed\n";
std::abort();
}

auto parser = TrtUniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, gLogger));
auto parser = TrtUnqPtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, gLogger));
if (!parser) {
cout << "create parser failed\n";
std::abort();
Expand All @@ -84,25 +84,45 @@ TrtSharedEnginePtr parse_to_engine(string onnx_pth, bool use_fp16) {
if (use_fp16 && builder->platformHasFastFp16()) {
config->setFlag(nvinfer1::BuilderFlag::kFP16); // fp16
}
// TODO: see if use dla or int8

auto output = network->getOutput(0);
output->setType(nvinfer1::DataType::kINT32);

cout << " start to build \n";
CudaStreamUnqPtr stream(new cudaStream_t);
if (cudaStreamCreate(stream.get())) {
cout << "create stream failed\n";
std::abort();
}
config->setProfileStream(*stream);

auto plan = TrtUnqPtr<IHostMemory>(builder->buildSerializedNetwork(*network, *config));
if (!plan) {
cout << "serialization failed\n";
std::abort();
}

auto runtime = TrtUnqPtr<IRuntime>(nvinfer1::createInferRuntime(gLogger));
if (!plan) {
cout << "create runtime failed\n";
std::abort();
}

TrtSharedEnginePtr engine = shared_engine_ptr(
builder->buildEngineWithConfig(*network, *config));
runtime->deserializeCudaEngine(plan->data(), plan->size()));
if (!engine) {
cout << "create engine failed\n";
std::abort();
}
cout << "done build engine \n";

return engine;
}


void serialize(TrtSharedEnginePtr engine, string save_path) {

auto trt_stream = TrtUniquePtr<IHostMemory>(engine->serialize());
auto trt_stream = TrtUnqPtr<IHostMemory>(engine->serialize());
if (!trt_stream) {
cout << "serialize engine failed\n";
std::abort();
Expand Down Expand Up @@ -132,7 +152,7 @@ TrtSharedEnginePtr deserialize(string serpth) {
ifile.close();
cout << "model size: " << mdsize << endl;

auto runtime = TrtUniquePtr<IRuntime>(nvinfer1::createInferRuntime(gLogger));
auto runtime = TrtUnqPtr<IRuntime>(nvinfer1::createInferRuntime(gLogger));
TrtSharedEnginePtr engine = shared_engine_ptr(
runtime->deserializeCudaEngine((void*)&buf[0], mdsize, nullptr));
return engine;
Expand All @@ -149,7 +169,7 @@ vector<int> infer_with_engine(TrtSharedEnginePtr engine, vector<float>& data) {
vector<void*> buffs(2);
vector<int> res(out_size);

auto context = TrtUniquePtr<IExecutionContext>(engine->createExecutionContext());
auto context = TrtUnqPtr<IExecutionContext>(engine->createExecutionContext());
if (!context) {
cout << "create execution context failed\n";
std::abort();
Expand All @@ -166,34 +186,32 @@ vector<int> infer_with_engine(TrtSharedEnginePtr engine, vector<float>& data) {
cout << "allocate memory failed\n";
std::abort();
}
cudaStream_t stream;
state = cudaStreamCreate(&stream);
if (state) {
CudaStreamUnqPtr stream(new cudaStream_t);
if (cudaStreamCreate(stream.get())) {
cout << "create stream failed\n";
std::abort();
}

state = cudaMemcpyAsync(
buffs[0], &data[0], in_size * sizeof(float),
cudaMemcpyHostToDevice, stream);
cudaMemcpyHostToDevice, *stream);
if (state) {
cout << "transmit to device failed\n";
std::abort();
}
context->enqueueV2(&buffs[0], stream, nullptr);
context->enqueueV2(&buffs[0], *stream, nullptr);
// context->enqueue(1, &buffs[0], stream, nullptr);
state = cudaMemcpyAsync(
&res[0], buffs[1], out_size * sizeof(int),
cudaMemcpyDeviceToHost, stream);
cudaMemcpyDeviceToHost, *stream);
if (state) {
cout << "transmit to host failed \n";
std::abort();
}
cudaStreamSynchronize(stream);
cudaStreamSynchronize(*stream);

cudaFree(buffs[0]);
cudaFree(buffs[1]);
cudaStreamDestroy(stream);

return res;
}
Expand All @@ -210,7 +228,7 @@ void test_fps_with_engine(TrtSharedEnginePtr engine) {
const int in_size{batchsize * 3 * iH * iW};
const int out_size{batchsize * oH * oW};

auto context = TrtUniquePtr<IExecutionContext>(engine->createExecutionContext());
auto context = TrtUnqPtr<IExecutionContext>(engine->createExecutionContext());
if (!context) {
cout << "create execution context failed\n";
std::abort();
Expand Down
14 changes: 10 additions & 4 deletions tensorrt/trt_dep.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include "NvOnnxParser.h"
#include "NvInferPlugin.h"
#include <cuda_runtime_api.h>
#include "NvInferRuntimeCommon.h"

#include <iostream>
#include <string>
Expand All @@ -25,7 +24,7 @@ using Severity = nvinfer1::ILogger::Severity;

class Logger: public ILogger {
public:
void log(Severity severity, const char* msg) override {
void log(Severity severity, const char* msg) noexcept override {
if (severity != Severity::kINFO) {
std::cout << msg << std::endl;
}
Expand All @@ -35,12 +34,19 @@ class Logger: public ILogger {
struct TrtDeleter {
template <typename T>
void operator()(T* obj) const {
if (obj) {obj->destroy();}
delete obj;
}
};

// Deleter for std::unique_ptr<cudaStream_t, CudaStreamDeleter> (CudaStreamUnqPtr).
// The owned pointer is heap-allocated with `new cudaStream_t` before
// cudaStreamCreate fills it in, so the deleter must both destroy the CUDA
// stream handle and free the pointer's storage. The original only called
// cudaStreamDestroy, leaking the heap-allocated cudaStream_t each time.
struct CudaStreamDeleter {
    void operator()(cudaStream_t* stream) const {
        if (stream) {                     // defensive; unique_ptr skips null anyway
            cudaStreamDestroy(*stream);   // release the CUDA stream resource
            delete stream;                // free the handle storage allocated with new
        }
    }
};

template <typename T>
using TrtUniquePtr = std::unique_ptr<T, TrtDeleter>;
using TrtUnqPtr = std::unique_ptr<T, TrtDeleter>;
using CudaStreamUnqPtr = std::unique_ptr<cudaStream_t, CudaStreamDeleter>;
using TrtSharedEnginePtr = std::shared_ptr<ICudaEngine>;


Expand Down

0 comments on commit 4dcf170

Please sign in to comment.