forked from onnx/onnx-tensorrt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNvOnnxParser.h
264 lines (235 loc) · 8.14 KB
/
NvOnnxParser.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
/*
* SPDX-License-Identifier: Apache-2.0
*/
#ifndef NV_ONNX_PARSER_H
#define NV_ONNX_PARSER_H
#include "NvInfer.h"
#include <stddef.h>
#include <utility>
#include <vector>
//!
//! \file NvOnnxParser.h
//!
//! This is the API for the ONNX Parser
//!
//! Major version of the ONNX parser API described by this header.
#define NV_ONNX_PARSER_MAJOR 0
//! Minor version of the ONNX parser API described by this header.
#define NV_ONNX_PARSER_MINOR 1
//! Patch version of the ONNX parser API described by this header.
#define NV_ONNX_PARSER_PATCH 0

//! Single-integer encoding of the version above: MAJOR * 10000 + MINOR * 100 + PATCH
//! (so 0.1.0 encodes to 100). createParser() forwards this value to the parser library.
static constexpr int NV_ONNX_PARSER_VERSION
    = NV_ONNX_PARSER_MAJOR * 10000 + NV_ONNX_PARSER_MINOR * 100 + NV_ONNX_PARSER_PATCH;
//! \typedef SubGraph_t
//!
//! \brief The data structure containing the parsing capability of
//! a set of nodes in an ONNX graph.
//!
//! \c first identifies the nodes of the subgraph (presumably by node index in
//! the ONNX graph) and \c second is a flag.
//! NOTE(review): the meaning of the flag is not stated in this header -- confirm
//! against the implementation of IParser::supportsModel before relying on it.
//!
//! std::pair is declared in <utility>, which must be included directly rather than
//! relied upon transitively through <vector>.
//!
using SubGraph_t = std::pair<std::vector<size_t>, bool>;

//! \typedef SubGraphCollection_t
//!
//! \brief The data structure containing all SubGraph_t partitioned
//! out of an ONNX graph.
//!
//! Filled by IParser::supportsModel through its non-const reference parameter.
//!
using SubGraphCollection_t = std::vector<SubGraph_t>;

//! Forward declaration only: this header uses the type solely through a pointer
//! (IParser::parseWithWeightDescriptors). NOTE(review): presumably the ONNXIFI tensor
//! descriptor from onnxifi.h; callers must include its full definition themselves.
class onnxTensorDescriptorV1;
//!
//! \namespace nvonnxparser
//!
//! \brief The TensorRT ONNX parser API namespace
//!
namespace nvonnxparser
{
//! \brief Number of enumerators in the enumeration type \p T.
//!
//! Only declared here; each enumeration provides an explicit specialization.
template <typename T>
inline int32_t EnumMax();

//!
//! \enum ErrorCode
//!
//! \brief The type of parser error.
//!
//! The numeric values are spelled out explicitly and are densely packed from 0.
//!
enum class ErrorCode : int
{
    kSUCCESS = 0,
    kINTERNAL_ERROR = 1,
    kMEM_ALLOC_FAILED = 2,
    kMODEL_DESERIALIZE_FAILED = 3,
    kINVALID_VALUE = 4,
    kINVALID_GRAPH = 5,
    kINVALID_NODE = 6,
    kUNSUPPORTED_GRAPH = 7,
    kUNSUPPORTED_NODE = 8
};

//! \brief Number of ErrorCode enumerators (kSUCCESS through kUNSUPPORTED_NODE).
template <>
inline int32_t EnumMax<ErrorCode>()
{
    // Values run 0..kUNSUPPORTED_NODE without gaps, so the count is the last value plus one.
    return static_cast<int32_t>(ErrorCode::kUNSUPPORTED_NODE) + 1;
}
/** \class IParserError
*
* \brief an object containing information about an error
*
* Instances are handed out by IParser::getError(). The destructor is protected, so
* callers cannot delete an IParserError through this interface.
*
* NOTE(review): ownership presumably stays with the parser that recorded the error;
* confirm how long the object (and the strings it returns) remain valid, e.g. across
* IParser::clearErrors() or IParser::destroy().
* NOTE(review): implementations presumably live in the prebuilt parser library, so the
* declaration order of the pure virtual functions below is ABI (vtable layout) --
* do not reorder, insert, or remove them.
*/
class IParserError
{
public:
/** \brief the error code
*
* \return one of the ErrorCode values
*/
virtual ErrorCode code() const = 0;
/** \brief description of the error
*
* \return a C string describing the error
*/
virtual const char* desc() const = 0;
/** \brief source file in which the error occurred
*
* NOTE(review): presumably the parser's own source file where the error was raised,
* not the ONNX model file; confirm.
*/
virtual const char* file() const = 0;
/** \brief source line at which the error occurred
*/
virtual int line() const = 0;
/** \brief source function in which the error occurred
*/
virtual const char* func() const = 0;
/** \brief index of the ONNX model node in which the error occurred
*
* NOTE(review): the value reported when an error is not tied to a specific node is
* not documented here; confirm before relying on it.
*/
virtual int node() const = 0;
protected:
// Protected so that callers cannot delete through this interface.
virtual ~IParserError() {}
};
/** \class IParser
*
* \brief an object for parsing ONNX models into a TensorRT network definition
*
* Instances are obtained from createParser() and released with destroy(); the
* destructor is protected, so an IParser cannot be deleted directly.
*
* NOTE(review): the implementation presumably lives in the separately built parser
* library (reached through createNvOnnxParser_INTERNAL). The declaration order of the
* pure virtual functions below therefore defines the vtable layout shared with that
* library -- do not reorder, insert, or remove them.
*/
class IParser
{
public:
/** \brief Parse a serialized ONNX model into the TensorRT network.
* This method has very limited diagnostics. If parsing the serialized model
* fails for any reason (e.g. unsupported IR version, unsupported opset, etc.)
* it is the user's responsibility to intercept and report the error.
* To obtain better diagnostics, use the parseFromFile method below.
*
* \param serialized_onnx_model Pointer to the serialized ONNX model
* \param serialized_onnx_model_size Size of the serialized ONNX model
* in bytes
* \param model_path Absolute path to the model file for loading external weights if required
* (defaults to nullptr)
* \return true if the model was parsed successfully
* \see getNbErrors() getError()
*/
virtual bool parse(void const* serialized_onnx_model,
size_t serialized_onnx_model_size,
const char* model_path = nullptr)
= 0;
/** \brief Parse an ONNX model file, which can be a binary protobuf or a text ONNX model.
* Calls the parse method internally.
*
* \param onnxModelFile Path of the ONNX model file to read
* \param verbosity Verbosity level of the diagnostics
* NOTE(review): the accepted range and meaning of verbosity values are not documented
* in this header; confirm against the implementation.
*
* \return true if the model was parsed successfully
*
*/
virtual bool parseFromFile(const char* onnxModelFile, int verbosity) = 0;
/** \brief Check whether TensorRT supports a particular ONNX model
*
* \param serialized_onnx_model Pointer to the serialized ONNX model
* \param serialized_onnx_model_size Size of the serialized ONNX model
* in bytes
* \param sub_graph_collection Container to hold supported subgraphs (output parameter,
* taken by non-const reference)
* \param model_path Absolute path to the model file for loading external weights if required
* (defaults to nullptr)
* \return true if the model is supported
*/
virtual bool supportsModel(void const* serialized_onnx_model,
size_t serialized_onnx_model_size,
SubGraphCollection_t& sub_graph_collection,
const char* model_path = nullptr)
= 0;
/** \brief Parse a serialized ONNX model into the TensorRT network
* with consideration of user provided weights
*
* \param serialized_onnx_model Pointer to the serialized ONNX model
* \param serialized_onnx_model_size Size of the serialized ONNX model
* in bytes
* \param weight_count number of user provided weights
* \param weight_descriptors pointer to user provided weight array of weight_count
* elements (onnxTensorDescriptorV1 is only forward-declared in this header)
* \return true if the model was parsed successfully
* \see getNbErrors() getError()
*/
virtual bool parseWithWeightDescriptors(
void const* serialized_onnx_model, size_t serialized_onnx_model_size,
uint32_t weight_count,
onnxTensorDescriptorV1 const* weight_descriptors)
= 0;
/** \brief Returns whether the specified operator may be supported by the
* parser.
*
* Note that a result of true does not guarantee that the operator will be
* supported in all cases (i.e., this function may return false-positives).
*
* \param op_name The name of the ONNX operator to check for support
* \return true if the operator may be supported (see the note above)
*/
virtual bool supportsOperator(const char* op_name) const = 0;
/** \brief destroy this object
*
* This is the only way to release a parser: the destructor is protected, so
* delete cannot be applied through an IParser pointer.
*/
virtual void destroy() = 0;
/** \brief Get the number of errors that occurred during prior calls to
* \p parse
*
* \return the number of recorded errors
* \see getError() clearErrors() IParserError
*/
virtual int getNbErrors() const = 0;
/** \brief Get an error that occurred during prior calls to \p parse
*
* \param index Index of the error to retrieve
* NOTE(review): presumably must be in [0, getNbErrors()); behavior for out-of-range
* indices is not documented here.
* \return pointer to the error object (not to be deleted by the caller; IParserError's
* destructor is protected)
* \see getNbErrors() clearErrors() IParserError
*/
virtual IParserError const* getError(int index) const = 0;
/** \brief Clear errors from prior calls to \p parse
*
* \see getNbErrors() getError() IParserError
*/
virtual void clearErrors() = 0;
/** \brief Get description of all ONNX weights that can be refitted.
*
* \param weightNames Where to write the weight names to
* \param layerNames Where to write the layer names to
* \param roles Where to write the roles to
*
* \return The number of weights from the ONNX model that can be refitted
*
* If weightNames or layerNames != nullptr, each written pointer points to a string owned by
* the parser, and becomes invalid when the parser is destroyed
*
* If the same weight is used in multiple TRT layers it will be represented as a new
* entry in weightNames with name <weightName>_x, with x being the number of times the weight
* has been used before the current layer
*
* NOTE(review): the output arrays presumably need room for at least the returned number
* of entries (e.g. query the count first with nullptr arguments); confirm, including
* whether roles may be nullptr.
*/
virtual int getRefitMap(const char** weightNames, const char** layerNames, nvinfer1::WeightsRole* roles) = 0;
protected:
// Protected: release parsers through destroy(), never through delete.
virtual ~IParser() {}
};
} // namespace nvonnxparser
//! \brief C-linkage factory exported by the ONNX parser library.
//!
//! Prefer the typed nvonnxparser::createParser() wrappers. They call this with \p network set
//! to the address of an nvinfer1::INetworkDefinition, \p logger set to the address of an
//! nvinfer1::ILogger and \p version set to NV_ONNX_PARSER_VERSION, and they static_cast the
//! result to nvonnxparser::IParser*. The void* parameters keep the signature C-compatible.
//! NOTE(review): the library presumably uses \p version to detect header/library mismatches
//! and returns nullptr on failure; confirm.
extern "C" TENSORRTAPI void* createNvOnnxParser_INTERNAL(void* network, void* logger, int version);
//! \brief Returns the version of the linked ONNX parser library.
//! NOTE(review): presumably encoded like NV_ONNX_PARSER_VERSION (MAJOR * 10000 + MINOR * 100 + PATCH),
//! so comparing the two detects a header/library mismatch; confirm.
extern "C" TENSORRTAPI int getNvOnnxParserVersion();
namespace nvonnxparser
{
#ifdef SWIG
//! \brief Pointer-based overload of createParser(), compiled only when SWIG is defined,
//! i.e. when this header is processed to generate language bindings.
//!
//! \param network The network definition that the parser will write to
//! \param logger The logger to use
//! \return the object created by createNvOnnxParser_INTERNAL, cast to IParser*
inline IParser* createParser(nvinfer1::INetworkDefinition* network,
nvinfer1::ILogger* logger)
{
return static_cast<IParser*>(
createNvOnnxParser_INTERNAL(network, logger, NV_ONNX_PARSER_VERSION));
}
#endif // SWIG
// NOTE(review): an unnamed namespace in a header gives every translation unit that includes
// it its own internal-linkage copy of createParser. On the _MSC_VER branch the function is
// declared TENSORRTAPI instead of inline, so the unnamed namespace is presumably what avoids
// duplicate-definition link errors there. Do not remove it without checking the Windows build.
namespace
{
/** \brief Create a new parser object
*
* \param network The network definition that the parser will write to
* \param logger The logger to use
* \return a new parser object or NULL if an error occurred
*
* Release the returned object with IParser::destroy(); IParser's destructor is protected,
* so it cannot be deleted directly. The header's NV_ONNX_PARSER_VERSION is passed to the
* library (presumably so it can check which header version the caller was built against).
* \see IParser
*/
#ifdef _MSC_VER
TENSORRTAPI IParser* createParser(nvinfer1::INetworkDefinition& network,
nvinfer1::ILogger& logger)
#else
inline IParser* createParser(nvinfer1::INetworkDefinition& network,
nvinfer1::ILogger& logger)
#endif
{
return static_cast<IParser*>(
createNvOnnxParser_INTERNAL(&network, &logger, NV_ONNX_PARSER_VERSION));
}
} // namespace
} // namespace nvonnxparser
#endif // NV_ONNX_PARSER_H