@@ -103,6 +103,7 @@ torch::jit::Node* createCastNode(
     SegmentedBlock& seg_block,
     size_t index,
     bool is_input,
+    at::ScalarType dtype,
     std::string device,
     bool force_create_node = false) {
   auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
@@ -115,7 +116,7 @@ torch::jit::Node* createCastNode(
     value_map.insert({cast_node->inputs()[0], cast_subgraph_value});
     if (!is_input) {
       // if this value is output, we need to cast it to int32
-      auto const_val = g->insertConstant(3);
+      auto const_val = g->insertConstant(dtype);
       if (cast_node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
         value_map.insert({cast_node->inputs()[2], const_val});
       } else {
@@ -127,7 +128,7 @@ torch::jit::Node* createCastNode(
     // auto cast_node = g->prependNode(g->createClone(cast_node, env));
   } else {
     // if there is no explicit cast aten::to operation, we need to create a node
-    auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
+    auto const_type = g->insertConstant(dtype);
     auto const_zero = g->insertConstant(0);
     const_zero->setType(torch::jit::BoolType::get());
     auto cuda = g->insertConstant(device);
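For context, the magic numbers this hunk removes are raw ATen ScalarType enum values: 3 is Int and 4 is Long. Threading an explicit at::ScalarType dtype through createCastNode keeps the same behaviour while making the intended type visible at each call site. A minimal standalone sketch, assuming only the public ATen header, that checks the correspondence:

    #include <ATen/ATen.h>
    #include <iostream>

    int main() {
      // ScalarType::Int == 3 and ScalarType::Long == 4, i.e. the values that
      // insertConstant(3) / insertConstant(4) used to hardcode above.
      std::cout << static_cast<int>(at::kInt) << " " << static_cast<int>(at::kLong) << "\n";  // prints "3 4"
      return 0;
    }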
@@ -230,17 +231,28 @@ void getSegmentsOutputByRunning(
   // auto int64 <=> int32 conversion + int8 <=> int32 conversion for non-quantized models
   if (seg_block.target() == SegmentedBlock::kTorch) {
     // First, check if there is Int64 input
-    if (partitioning_info.truncate_long_and_double) {
-      for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
-        if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
-          auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
-          at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-          if (t == at::kLong) {
-            // we add a cast operation to cast the type to Int64
-            auto cast_node = createCastNode(seg_block, i, true, target_device);
-            seg_block.g()->prependNode(cast_node);
-            seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
-          }
+    for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
+      if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
+        auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
+        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
+        if (t == at::kLong && partitioning_info.truncate_long_and_double) {
+          LOG_DEBUG(
+              "Detected graph Long tensor input type during shape analysis, "
+              << "inserting aten::to cast to Long to ensure this Torch block receives "
+              << "a Long-type tensor input.");
+          // we add a cast operation to cast the type to Int64
+          auto cast_node = createCastNode(seg_block, i, true, at::kLong, target_device);
+          seg_block.g()->prependNode(cast_node);
+          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+        } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
+          LOG_DEBUG(
+              "Detected graph Byte tensor input type during shape analysis, "
+              << "inserting aten::to cast to Byte to ensure this Torch block receives "
+              << "a Byte-type tensor input.");
+          // If the input has type Byte, ensure it is casted to the correct type
+          auto cast_node = createCastNode(seg_block, i, true, at::kByte, target_device, /*force_create_node=*/true);
+          seg_block.g()->prependNode(cast_node);
+          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
         }
       }
     }
@@ -250,14 +262,22 @@ void getSegmentsOutputByRunning(
         auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
         at::ScalarType t = cur_ivalue.toTensor().scalar_type();
 
-        // If the input has type Long and truncation was requested, insert truncate
+        // If the output has type Long and truncation was requested, insert truncate
         if (t == at::kLong && partitioning_info.truncate_long_and_double) {
-          auto cast_node = createCastNode(seg_block, i, false, target_device);
+          LOG_DEBUG(
+              "Detected graph Long tensor output type during shape analysis, "
+              << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+              << "receives an Int-type tensor input.");
+          auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device);
           seg_block.g()->appendNode(cast_node);
           seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
         } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
-          // If the input has type Byte and truncation was requested, insert Integer cast
-          auto cast_node = createCastNode(seg_block, i, false, target_device, /*force_create_node=*/true);
+          LOG_DEBUG(
+              "Detected graph Byte tensor output type during shape analysis, "
+              << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+              << "receives an Int-type tensor input.");
+          // If the output has type Byte and casting was requested, insert Integer cast
+          auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device, /*force_create_node=*/true);
           seg_block.g()->appendNode(cast_node);
           seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
         }
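Net effect of the two loops above: when a segment falls back to Torch, Long inputs are re-widened to Long (when truncate_long_and_double is set) and Byte inputs are kept as Byte (when cast_int8_inputs is set), while Long and Byte outputs that feed a following TensorRT segment are both narrowed to Int. A minimal sketch of the two directions using the updated signature, with a hypothetical index 0 and the same seg_block/target_device names as in the diff:

    // Input entering the Torch block: restore the wide type that truncation removed.
    auto in_cast = createCastNode(seg_block, /*index=*/0, /*is_input=*/true, at::kLong, target_device);
    // Output leaving the Torch block toward TensorRT: narrow to Int, forcing a fresh aten::to node.
    auto out_cast = createCastNode(
        seg_block, /*index=*/0, /*is_input=*/false, at::kInt, target_device, /*force_create_node=*/true);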