Commit d74e0b5
fix: Improve logging, restructure casting function
- Address review comments
- Improve documentation and logging messages
- Restructure casting function to allow for casting of variable data types
- Add casting for `at::kByte` segment block inputs as well as segment block outputs
Parent: a4c2d60

1 file changed: core/partitioning/shape_analysis.cpp (+37 −17)
@@ -103,6 +103,7 @@ torch::jit::Node* createCastNode(
     SegmentedBlock& seg_block,
     size_t index,
     bool is_input,
+    at::ScalarType dtype,
     std::string device,
     bool force_create_node = false) {
   auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
@@ -115,7 +116,7 @@
     value_map.insert({cast_node->inputs()[0], cast_subgraph_value});
     if (!is_input) {
       // if this value is output, we need to cast it to int32
-      auto const_val = g->insertConstant(3);
+      auto const_val = g->insertConstant(dtype);
       if (cast_node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
         value_map.insert({cast_node->inputs()[2], const_val});
       } else {
@@ -127,7 +128,7 @@
     // auto cast_node = g->prependNode(g->createClone(cast_node, env));
   } else {
     // if there is no explicit cast aten::to operation, we need to create a node
-    auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
+    auto const_type = g->insertConstant(dtype);
     auto const_zero = g->insertConstant(0);
     const_zero->setType(torch::jit::BoolType::get());
     auto cuda = g->insertConstant(device);
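
The two hunks above replace magic numbers with the new `dtype` parameter: 3 and 4 are the underlying integer values of `at::kInt` and `at::kLong` in PyTorch's `c10::ScalarType` enum, so the old code was implicitly selecting Int or Long. A minimal standalone sketch (not part of this commit) that prints the enum values involved:

```cpp
#include <c10/core/ScalarType.h>
#include <iostream>

int main() {
  // The magic numbers previously passed to insertConstant() are the
  // underlying values of PyTorch's c10::ScalarType enum:
  std::cout << static_cast<int>(c10::kByte) << "\n"; // 0
  std::cout << static_cast<int>(c10::kInt) << "\n";  // 3
  std::cout << static_cast<int>(c10::kLong) << "\n"; // 4
}
```
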
@@ -230,17 +231,28 @@ void getSegmentsOutputByRunning(
   // auto int64 <=> int32 conversion + int8 <=> int32 conversion for non-quantized models
   if (seg_block.target() == SegmentedBlock::kTorch) {
     // First, check if there is Int64 input
-    if (partitioning_info.truncate_long_and_double) {
-      for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
-        if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
-          auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
-          at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-          if (t == at::kLong) {
-            // we add a cast operation to cast the type to Int64
-            auto cast_node = createCastNode(seg_block, i, true, target_device);
-            seg_block.g()->prependNode(cast_node);
-            seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
-          }
+    for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
+      if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
+        auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
+        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
+        if (t == at::kLong && partitioning_info.truncate_long_and_double) {
+          LOG_DEBUG(
+              "Detected graph Long tensor input type during shape analysis, "
+              << "inserting aten::to cast to Long to ensure this Torch block receives "
+              << "a Long-type tensor input.");
+          // we add a cast operation to cast the type to Int64
+          auto cast_node = createCastNode(seg_block, i, true, at::kLong, target_device);
+          seg_block.g()->prependNode(cast_node);
+          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+        } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
+          LOG_DEBUG(
+              "Detected graph Byte tensor input type during shape analysis, "
+              << "inserting aten::to cast to Byte to ensure this Torch block receives "
+              << "a Byte-type tensor input.");
+          // If the input has type Byte, ensure it is casted to the correct type
+          auto cast_node = createCastNode(seg_block, i, true, at::kByte, target_device, /*force_create_node=*/true);
+          seg_block.g()->prependNode(cast_node);
+          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
         }
       }
     }
@@ -250,14 +262,22 @@ void getSegmentsOutputByRunning(
         auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
         at::ScalarType t = cur_ivalue.toTensor().scalar_type();
 
-        // If the input has type Long and truncation was requested, insert truncate
+        // If the output has type Long and truncation was requested, insert truncate
         if (t == at::kLong && partitioning_info.truncate_long_and_double) {
-          auto cast_node = createCastNode(seg_block, i, false, target_device);
+          LOG_DEBUG(
+              "Detected graph Long tensor output type during shape analysis, "
+              << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+              << "receives an Int-type tensor input.");
+          auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device);
           seg_block.g()->appendNode(cast_node);
           seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
         } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
-          // If the input has type Byte and truncation was requested, insert Integer cast
-          auto cast_node = createCastNode(seg_block, i, false, target_device, /*force_create_node=*/true);
+          LOG_DEBUG(
+              "Detected graph Byte tensor output type during shape analysis, "
+              << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+              << "receives an Int-type tensor input.");
+          // If the output has type Byte and casting was requested, insert Integer cast
+          auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device, /*force_create_node=*/true);
           seg_block.g()->appendNode(cast_node);
           seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
         }
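
To illustrate what the inserted aten::to casts do at runtime, here is a standalone libtorch sketch (illustrative only, not code from this commit). Long-to-Int truncation can wrap values that exceed the 32-bit range, while Byte-to-Int is a lossless widening, which matches the gating above: the Long cast requires `truncate_long_and_double`, the Byte cast only `cast_int8_inputs`.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Long -> Int mirrors the truncating output cast inserted above:
  // values outside the 32-bit range wrap (2^40 becomes 0).
  auto long_out = torch::tensor({int64_t(1) << 40}, torch::kLong);
  std::cout << long_out.to(torch::kInt) << "\n";

  // Byte -> Int mirrors the new Byte output cast: lossless widening.
  auto byte_out = torch::tensor({200}, torch::kByte);
  std::cout << byte_out.to(torch::kInt) << "\n";
}
```
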
