@@ -99,18 +99,24 @@ torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
  return nullptr;
}

-torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input, std::string device) {
+torch::jit::Node* createCastNode(
+    SegmentedBlock& seg_block,
+    size_t index,
+    bool is_input,
+    at::ScalarType dtype,
+    std::string device,
+    bool force_create_node = false) {
  auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
  auto cast_subgraph_value = is_input ? seg_block.inputs()[index] : seg_block.outputs()[index];
  torch::jit::Node* cast_node = getUpstreamCastNode(cast_raw_value);
  auto g = seg_block.g();
  // if we can find an upstream aten::to node, we use its parameters for creating the new cast node
-  if (cast_node) {
+  if (cast_node && !force_create_node) {
    std::unordered_map<torch::jit::Value*, torch::jit::Value*> value_map;
    value_map.insert({cast_node->inputs()[0], cast_subgraph_value});
    if (!is_input) {
      // if this value is an output, we need to cast it to int32
-      auto const_val = g->insertConstant(3);
+      auto const_val = g->insertConstant(dtype);
      if (cast_node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
        value_map.insert({cast_node->inputs()[2], const_val});
      } else {
@@ -122,7 +128,7 @@ torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool i
    // auto cast_node = g->prependNode(g->createClone(cast_node, env));
  } else {
    // if there is no explicit cast aten::to operation, we need to create a node
-    auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
+    auto const_type = g->insertConstant(dtype);
    auto const_zero = g->insertConstant(0);
    const_zero->setType(torch::jit::BoolType::get());
    auto cuda = g->insertConstant(device);
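
A note on the constants removed above: in the ATen ScalarType enum the value 3 is Int and 4 is Long, which is all the old insertConstant(3)/insertConstant(4) calls encoded; the new at::ScalarType dtype parameter carries the same information explicitly and also allows Byte (0) casts for the non-quantized Int8 path introduced below. A minimal standalone sketch (libtorch only, not part of this patch) that demonstrates the correspondence:

```cpp
// Minimal sketch, not part of the patch: shows that the old hard-coded constants
// 3 and 4 are just the numeric values of the ATen ScalarType enum, which the new
// `dtype` parameter now carries explicitly.
#include <torch/script.h>
#include <iostream>

int main() {
  std::cout << static_cast<int>(at::kInt) << "\n";  // 3 -> old g->insertConstant(3)
  std::cout << static_cast<int>(at::kLong) << "\n"; // 4 -> old g->insertConstant(4)
  std::cout << static_cast<int>(at::kByte) << "\n"; // 0 -> used by the new Byte handling
  return 0;
}
```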
@@ -222,27 +228,56 @@ void getSegmentsOutputByRunning(

  auto target_device = partitioning_info.getGPUDeviceString();

-  // auto int64 <=> int32 conversion
-  if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
+  // auto int64 <=> int32 conversion + int8 <=> int32 conversion for non-quantized models
+  if (seg_block.target() == SegmentedBlock::kTorch) {
    // First, check if there is Int64 input
    for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
      if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
        auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-        if (t == at::kLong) {
+        if (t == at::kLong && partitioning_info.truncate_long_and_double) {
+          LOG_DEBUG(
+              "Detected graph Long tensor input type during shape analysis, "
+              << "inserting aten::to cast to Long to ensure this Torch block receives "
+              << "a Long-type tensor input.");
          // we add a cast operation to cast the type to Int64
-          auto cast_node = createCastNode(seg_block, i, true, target_device);
+          auto cast_node = createCastNode(seg_block, i, true, at::kLong, target_device);
+          seg_block.g()->prependNode(cast_node);
+          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+        } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
+          LOG_DEBUG(
+              "Detected graph Byte tensor input type during shape analysis, "
+              << "inserting aten::to cast to Byte to ensure this Torch block receives "
+              << "a Byte-type tensor input.");
+          // If the input has type Byte, ensure it is cast to the correct type
+          auto cast_node = createCastNode(seg_block, i, true, at::kByte, target_device, /*force_create_node=*/true);
          seg_block.g()->prependNode(cast_node);
          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
        }
      }
    }
+
    for (size_t i = 0; i < seg_block.outputs().size(); ++i) {
      if (ivalues_maps[seg_block.raw_outputs()[i]].isTensor()) {
        auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-        if (t == at::kLong) {
-          auto cast_node = createCastNode(seg_block, i, false, target_device);
+
+        // If the output has type Long and truncation was requested, insert a truncating cast
+        if (t == at::kLong && partitioning_info.truncate_long_and_double) {
+          LOG_DEBUG(
+              "Detected graph Long tensor output type during shape analysis, "
+              << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+              << "receives an Int-type tensor input.");
+          auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device);
+          seg_block.g()->appendNode(cast_node);
+          seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
+        } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
+          LOG_DEBUG(
+              "Detected graph Byte tensor output type during shape analysis, "
+              << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+              << "receives an Int-type tensor input.");
+          // If the output has type Byte and casting was requested, insert an Integer cast
+          auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device, /*force_create_node=*/true);
          seg_block.g()->appendNode(cast_node);
          seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
        }
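
To make the direction of these casts concrete, here is a hedged eager-mode sketch (libtorch only; tensor names are illustrative and not taken from the real graphs): TensorRT segments hand back Int32, so a Torch block that originally consumed Int64 gets its input widened to Long, while Long or Byte outputs of a Torch block are converted to Int32 before they feed the next TensorRT segment.

```cpp
// Eager-mode sketch of what the inserted aten::to nodes accomplish (assumes libtorch;
// variable names are illustrative, not taken from the real graphs).
#include <torch/torch.h>
#include <iostream>

int main() {
  // Input side: TensorRT segments produce Int32, so a Torch block that originally
  // consumed Int64 gets its input widened back to Long.
  auto from_trt = torch::arange(6, torch::kInt);
  auto torch_block_input = from_trt.to(torch::kLong);

  // Output side: Long (or Byte) results of a Torch block are converted to Int32
  // before being fed to the next TensorRT segment.
  auto torch_block_output = torch::arange(6, torch::kLong);
  auto next_trt_input = torch_block_output.to(torch::kInt);

  std::cout << torch_block_input.scalar_type() << " " << next_trt_input.scalar_type() << std::endl;
  return 0;
}
```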
@@ -254,11 +289,13 @@ void getSegmentsOutputByRunning(
  std::vector<std::vector<int64_t>> input_shapes;
  std::vector<at::ScalarType> input_types;
  for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
-    if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
+    auto current_input = seg_block.raw_inputs()[i];
+
+    if (ivalues_maps[current_input].isTensor()) {
      // set the input_shape and data_type
      // we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
      // shape inference
-      auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
+      auto cur_ivalue = ivalues_maps[current_input];
      at::ScalarType t = cur_ivalue.toTensor().scalar_type();

      if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
@@ -271,10 +308,16 @@ void getSegmentsOutputByRunning(
        cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
        LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
      }
+
      c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
      if (dtype == c10::nullopt) {
        TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype());
+      } else if (dtype && dtype.value() == nvinfer1::DataType::kINT8 && partitioning_info.cast_int8_inputs) {
+        // Special case to ensure input IValues to TensorRT engine are not Int8 type if the
+        // model itself is not quantized
+        cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
      }
+
      if (cur_ivalue.toTensor().sizes().size() == 0) {
        // handle Scalar types, which has sizes of []
        input_shapes.push_back(util::toVec(util::toDims(c10::List<int64_t>({1}))));
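
The new else-if branch above guards the IValues used for shape analysis: when an input's dtype would map to TensorRT's kINT8 but cast_int8_inputs is set (i.e. the model is not quantized), the tensor is widened to Int32 before its shape and type are recorded. Below is a minimal libtorch-only sketch of that guard; the dtype-to-TensorRT mapping done by util::optTypeMetaToTRTDataType is assumed and simplified here to a check against int8-like ATen types, so this is illustrative rather than the library's actual logic.

```cpp
// Standalone sketch of the Int8 input guard (libtorch only, no TensorRT).
// Assumption: int8-like ATen types (kChar/kByte used here as stand-ins) are the
// ones that would map to nvinfer1::DataType::kINT8 in the real code.
#include <torch/torch.h>
#include <iostream>

at::Tensor widenInt8Input(const at::Tensor& t, bool cast_int8_inputs) {
  auto st = t.scalar_type();
  if (cast_int8_inputs && (st == at::kChar || st == at::kByte)) {
    // Non-quantized models should not present Int8 IValues to the TensorRT engine,
    // so widen to Int32 for shape analysis, mirroring the diff above.
    return t.to(at::kInt);
  }
  return t;
}

int main() {
  auto x = torch::ones({4}, torch::kChar);
  std::cout << widenInt8Input(x, /*cast_int8_inputs=*/true).scalar_type() << std::endl;
  return 0;
}
```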
@@ -297,6 +340,7 @@ void runShapeAnalysis(
    const ir::ShapeMode& shape_mode) {
  // register every segment's input shape, and its running output IValues
  for (auto& seg_block : ctx->partitioned_blocks[block]) {
+    LOG_GRAPH("Running shape analysis on block " << seg_block);
    torch::jit::ConstantPooling(seg_block.g());
    getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings, shape_mode);
  }