
Commit 544654f

Merge pull request #1549 from gs-olive/gpt_2_bugfix
fix: Properly cast intermediate Int8 tensors to TensorRT Engines in Fallback
2 parents d87e4a6 + d74e0b5 commit 544654f

6 files changed: +138 -12 lines

core/partitioning/partitioninginfo/PartitioningInfo.h (+1)

@@ -17,6 +17,7 @@ struct PartitioningInfo {
   std::vector<std::string> forced_fallback_operators;
   bool truncate_long_and_double;
   ir::Device target_device;
+  bool cast_int8_inputs = false;
 
   std::string getGPUDeviceString() const {
     return "cuda:" + std::to_string(target_device.gpu_id);

core/partitioning/shape_analysis.cpp (+56 -12)

@@ -99,18 +99,24 @@ torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
   return nullptr;
 }
 
-torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input, std::string device) {
+torch::jit::Node* createCastNode(
+    SegmentedBlock& seg_block,
+    size_t index,
+    bool is_input,
+    at::ScalarType dtype,
+    std::string device,
+    bool force_create_node = false) {
   auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
   auto cast_subgraph_value = is_input ? seg_block.inputs()[index] : seg_block.outputs()[index];
   torch::jit::Node* cast_node = getUpstreamCastNode(cast_raw_value);
   auto g = seg_block.g();
   // if we can find upstream aten::to node, we use it's parameters for creating new cast node
-  if (cast_node) {
+  if (cast_node && !force_create_node) {
     std::unordered_map<torch::jit::Value*, torch::jit::Value*> value_map;
     value_map.insert({cast_node->inputs()[0], cast_subgraph_value});
     if (!is_input) {
       // if this value is output, we need to cast it to int32
-      auto const_val = g->insertConstant(3);
+      auto const_val = g->insertConstant(dtype);
       if (cast_node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
         value_map.insert({cast_node->inputs()[2], const_val});
       } else {
@@ -122,7 +128,7 @@ torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool i
     // auto cast_node = g->prependNode(g->createClone(cast_node, env));
   } else {
     // if there is no explicit cast aten::to operation, we need to create a node
-    auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
+    auto const_type = g->insertConstant(dtype);
    auto const_zero = g->insertConstant(0);
     const_zero->setType(torch::jit::BoolType::get());
     auto cuda = g->insertConstant(device);
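
Note on the change above: the integer literals that used to be passed to g->insertConstant (3 and 4) are raw c10::ScalarType enum values, so threading an at::ScalarType through createCastNode makes the intended dtype explicit at each call site. A minimal libtorch check of those enum values (a standalone sketch, not part of the patch):

#include <torch/torch.h>
#include <iostream>

int main() {
  // The integer constants previously inserted into the graph correspond to
  // c10::ScalarType enum values:
  std::cout << static_cast<int>(at::kByte) << std::endl; // 0 (unsigned 8-bit)
  std::cout << static_cast<int>(at::kInt) << std::endl;  // 3 (was g->insertConstant(3))
  std::cout << static_cast<int>(at::kLong) << std::endl; // 4 (was g->insertConstant(4))
  return 0;
}
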
@@ -222,27 +228,56 @@ void getSegmentsOutputByRunning(
 
   auto target_device = partitioning_info.getGPUDeviceString();
 
-  // auto int64 <=> int32 conversion
-  if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
+  // auto int64 <=> int32 conversion + int8 <=> int32 conversion for non-quantized models
+  if (seg_block.target() == SegmentedBlock::kTorch) {
     // First, check if there is Int64 input
     for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
       if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
         auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
         at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-        if (t == at::kLong) {
+        if (t == at::kLong && partitioning_info.truncate_long_and_double) {
+          LOG_DEBUG(
+              "Detected graph Long tensor input type during shape analysis, "
+              << "inserting aten::to cast to Long to ensure this Torch block receives "
+              << "a Long-type tensor input.");
           // we add a cast operation to cast the type to Int64
-          auto cast_node = createCastNode(seg_block, i, true, target_device);
+          auto cast_node = createCastNode(seg_block, i, true, at::kLong, target_device);
+          seg_block.g()->prependNode(cast_node);
+          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+        } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
+          LOG_DEBUG(
+              "Detected graph Byte tensor input type during shape analysis, "
+              << "inserting aten::to cast to Byte to ensure this Torch block receives "
+              << "a Byte-type tensor input.");
+          // If the input has type Byte, ensure it is casted to the correct type
+          auto cast_node = createCastNode(seg_block, i, true, at::kByte, target_device, /*force_create_node=*/true);
           seg_block.g()->prependNode(cast_node);
           seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
         }
       }
     }
+
     for (size_t i = 0; i < seg_block.outputs().size(); ++i) {
       if (ivalues_maps[seg_block.raw_outputs()[i]].isTensor()) {
         auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
         at::ScalarType t = cur_ivalue.toTensor().scalar_type();
-        if (t == at::kLong) {
-          auto cast_node = createCastNode(seg_block, i, false, target_device);
+
+        // If the output has type Long and truncation was requested, insert truncate
+        if (t == at::kLong && partitioning_info.truncate_long_and_double) {
+          LOG_DEBUG(
+              "Detected graph Long tensor output type during shape analysis, "
+              << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+              << "receives an Int-type tensor input.");
+          auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device);
+          seg_block.g()->appendNode(cast_node);
+          seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
+        } else if (t == at::kByte && partitioning_info.cast_int8_inputs) {
+          LOG_DEBUG(
+              "Detected graph Byte tensor output type during shape analysis, "
+              << "inserting aten::to cast to Int to ensure the subsequent TensorRT block "
+              << "receives an Int-type tensor input.");
+          // If the output has type Byte and casting was requested, insert Integer cast
+          auto cast_node = createCastNode(seg_block, i, false, at::kInt, target_device, /*force_create_node=*/true);
           seg_block.g()->appendNode(cast_node);
           seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
         }
@@ -254,11 +289,13 @@ void getSegmentsOutputByRunning(
   std::vector<std::vector<int64_t>> input_shapes;
   std::vector<at::ScalarType> input_types;
   for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
-    if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
+    auto current_input = seg_block.raw_inputs()[i];
+
+    if (ivalues_maps[current_input].isTensor()) {
       // set the input_shape and data_type
       // we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
       // shape inference
-      auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
+      auto cur_ivalue = ivalues_maps[current_input];
       at::ScalarType t = cur_ivalue.toTensor().scalar_type();
 
       if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
@@ -271,10 +308,16 @@ void getSegmentsOutputByRunning(
         cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
         LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
       }
+
       c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
       if (dtype == c10::nullopt) {
         TORCHTRT_THROW_ERROR("Unsupported input data type " << cur_ivalue.toTensor().dtype());
+      } else if (dtype && dtype.value() == nvinfer1::DataType::kINT8 && partitioning_info.cast_int8_inputs) {
+        // Special case to ensure input IValues to TensorRT engine are not Int8 type if the
+        // model itself is not quantized
+        cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
       }
+
       if (cur_ivalue.toTensor().sizes().size() == 0) {
         // handle Scalar types, which has sizes of []
         input_shapes.push_back(util::toVec(util::toDims(c10::List<int64_t>({1}))));
@@ -297,6 +340,7 @@ void runShapeAnalysis(
     const ir::ShapeMode& shape_mode) {
   // register every segment's input shape, and it's running output IValues
   for (auto& seg_block : ctx->partitioned_blocks[block]) {
+    LOG_GRAPH("Running shape analysis on block " << seg_block);
     torch::jit::ConstantPooling(seg_block.g());
     getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings, shape_mode);
   }
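
The net effect of the new kINT8 special case is that a non-quantized model's 8-bit intermediate tensors are promoted to Int before their shapes and types are recorded for a TensorRT block. A minimal sketch restating that behavior in plain libtorch, using a hypothetical prepare_trt_block_input helper rather than the real getSegmentsOutputByRunning path:

#include <torch/torch.h>
#include <iostream>

// Hypothetical restatement: if a non-quantized model would hand an 8-bit
// tensor to a TensorRT block, promote it to Int first.
at::Tensor prepare_trt_block_input(const at::Tensor& t, bool cast_int8_inputs) {
  bool is_int8 = (t.scalar_type() == at::kByte || t.scalar_type() == at::kChar);
  if (cast_int8_inputs && is_int8) {
    return t.to(at::kInt);
  }
  return t;
}

int main() {
  auto byte_tensor = torch::ones({5, 5}, at::kByte);
  std::cout << prepare_trt_block_input(byte_tensor, true).scalar_type() << "\n";  // Int
  std::cout << prepare_trt_block_input(byte_tensor, false).scalar_type() << "\n"; // Byte
  return 0;
}
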

core/util/trt_util.cpp (+1)

@@ -252,6 +252,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
       {at::kHalf, nvinfer1::DataType::kHALF},
       {at::kInt, nvinfer1::DataType::kINT32},
       {at::kChar, nvinfer1::DataType::kINT8},
+      {at::kByte, nvinfer1::DataType::kINT8},
       {at::kBool, nvinfer1::DataType::kBOOL}};
   return at_trt_type_map;
 }
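
This one-line entry matters because the shape-analysis path above resolves each tensor dtype through this at::ScalarType to nvinfer1::DataType map: a dtype with no entry hits the "Unsupported input data type" error, and the new kINT8 branch only fires for dtypes that resolve to kINT8. A minimal sketch with a hypothetical stand-in map (strings instead of nvinfer1::DataType, so it builds without the TensorRT headers):

#include <torch/torch.h>
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  // Hypothetical stand-in for the real type map: before this commit only
  // at::kChar (signed int8) was present; at::kByte (unsigned 8-bit) now also
  // resolves to INT8 instead of failing the lookup.
  std::unordered_map<at::ScalarType, std::string> at_trt_type_map = {
      {at::kChar, "kINT8"},
      {at::kByte, "kINT8"}, // entry added in this commit
  };
  auto byte_dtype = torch::ones({1}, at::kByte).scalar_type();
  std::cout << at_trt_type_map.at(byte_dtype) << std::endl; // kINT8
  return 0;
}
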

cpp/src/compile_spec.cpp (+3)

@@ -167,8 +167,11 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.dla_local_dram_size = external.dla_local_dram_size;
   internal.convert_info.engine_settings.dla_global_dram_size = external.dla_global_dram_size;
 
+  internal.partitioning_info.cast_int8_inputs = true;
+
   if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
       internal.convert_info.engine_settings.enabled_precisions.end()) {
+    internal.partitioning_info.cast_int8_inputs = false;
     if (external.ptq_calibrator) {
       internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
     } else {

py/torch_tensorrt/csrc/tensorrt_classes.cpp (+17)

@@ -300,11 +300,15 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
     info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
   }
 
+  info.partitioning_info.cast_int8_inputs = true;
+
   if (ptq_calibrator) {
     info.convert_info.engine_settings.calibrator = ptq_calibrator;
+    info.partitioning_info.cast_int8_inputs = false;
   } else {
     if (info.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
         info.convert_info.engine_settings.enabled_precisions.end()) {
+      info.partitioning_info.cast_int8_inputs = false;
       info.lower_info.unfreeze_module = true;
       info.lower_info.disable_cse = true;
     }
@@ -313,10 +317,23 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   info.convert_info.engine_settings.disable_tf32 = disable_tf32;
   info.convert_info.engine_settings.refit = refit;
   info.convert_info.engine_settings.debug = debug;
+
+  // Specify + replicate device settings for phases requiring it
   info.convert_info.engine_settings.device.device_type = toTRTDeviceType(device.device_type);
   info.convert_info.engine_settings.device.gpu_id = device.gpu_id;
   info.convert_info.engine_settings.device.dla_core = device.dla_core;
   info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback;
+
+  info.lower_info.target_device.device_type = toTRTDeviceType(device.device_type);
+  info.lower_info.target_device.gpu_id = device.gpu_id;
+  info.lower_info.target_device.dla_core = device.dla_core;
+  info.lower_info.target_device.allow_gpu_fallback = device.allow_gpu_fallback;
+
+  info.partitioning_info.target_device.device_type = toTRTDeviceType(device.device_type);
+  info.partitioning_info.target_device.gpu_id = device.gpu_id;
+  info.partitioning_info.target_device.dla_core = device.dla_core;
+  info.partitioning_info.target_device.allow_gpu_fallback = device.allow_gpu_fallback;
+
   info.partitioning_info.enabled = torch_fallback.enabled;
   info.partitioning_info.min_block_size = torch_fallback.min_block_size;
   info.partitioning_info.forced_fallback_operators = torch_fallback.forced_fallback_operators;
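
Both spec-translation paths (cpp/src/compile_spec.cpp above and this Python-binding path) follow roughly the same rule for the new flag: cast_int8_inputs defaults to true and is switched off when INT8 compilation is actually requested, via a PTQ calibrator or the enabled precisions. A standalone restatement of that decision using a hypothetical should_cast_int8_inputs helper (a sketch, not the project's API):

#include <iostream>

// Hypothetical restatement of the flag logic introduced in this commit:
// intermediate Int8 tensors are only cast to Int32 when the model is NOT
// being compiled for INT8 (i.e. it is not a quantized workflow).
bool should_cast_int8_inputs(bool int8_precision_enabled, bool has_ptq_calibrator) {
  bool cast_int8_inputs = true;
  if (has_ptq_calibrator || int8_precision_enabled) {
    cast_int8_inputs = false;
  }
  return cast_int8_inputs;
}

int main() {
  std::cout << std::boolalpha;
  std::cout << should_cast_int8_inputs(false, false) << "\n"; // true  - FP32/FP16-only compilation
  std::cout << should_cast_int8_inputs(true, false) << "\n";  // false - INT8 precision enabled
  std::cout << should_cast_int8_inputs(true, true) << "\n";   // false - PTQ calibration workflow
  return 0;
}
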

tests/core/partitioning/test_type_auto_conversion.cpp (+60)

@@ -107,3 +107,63 @@ TEST(Partitioning, ImplicitAutoConversionCorrectly) {
   }
   ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
 }
+
+TEST(Partitioning, ExplicitNodeAutoInt8ConversionCorrectly) {
+  const auto graph = R"IR(
+        graph(%x.1 : Tensor,
+              %y.1 : Tensor):
+
+          %26 : int = prim::Constant[value=1]()
+          %21 : bool = prim::Constant[value=0]()
+          %60 : Device = prim::Constant[value="cuda"]()
+          %14 : NoneType = prim::Constant()
+          %3 : int = prim::Constant[value=5]()
+          %19 : int = prim::Constant[value=0]()
+          %29 : int = prim::Constant[value=2]()
+          %13 : int[] = prim::ListConstruct(%3, %3)
+          %k_.1 : Tensor = aten::ones(%13, %19, %14, %60, %14)
+          %20 : int[] = prim::ListConstruct(%19)
+          %k.1 : Tensor = aten::sum(%k_.1, %20, %21, %14)
+          %x.5 : Tensor = aten::add_(%x.1, %y.1, %26)
+          %31 : Tensor = aten::mul(%y.1, %29)
+          %x.9 : Tensor = aten::add_(%x.5, %31, %26)
+          %x.13 : Tensor = aten::add_(%x.9, %k.1, %26)
+          %x.17 : Tensor = aten::sub_(%x.13, %k.1, %26)
+          %x.21 : Tensor = aten::add_(%x.17, %k.1, %26)
+          %x.25 : Tensor = aten::sub_(%x.21, %k.1, %26)
+
+          return (%x.25))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get(), true);
+
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  partitioning_info.cast_int8_inputs = true;
+  partitioning_info.forced_fallback_operators = {"aten::ones"};
+  partitioning_info.truncate_long_and_double = true;
+  std::vector<torch_tensorrt::core::ir::Input> inputs;
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+
+  std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
+  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
+  inputs_map.insert({g->inputs()[0], {inputs[0]}});
+  input_types.insert({g->inputs()[0], {{at::kFloat}}});
+  inputs_map.insert({g->inputs()[1], {inputs[1]}});
+  input_types.insert({g->inputs()[1], {{at::kInt}}});
+
+  partitioning_info.collection_input_spec_map = inputs_map;
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  ctx.input_types_map = input_types;
+  torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
+  torch_tensorrt::core::partitioning::partition(&ctx);
+  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
+
+  for (auto& seg_block : segmented_blocks) {
+    LOG_DEBUG(seg_block << " cur seg block");
+  }
+
+  // Seeking 1 inserted aten::to converting Byte to Int (%k_.1 is a Byte Tensor)
+  ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[0], 1));
+}
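
The new test hinges on %k_.1 being a Byte tensor: %19 is the integer constant 0, and aten::ones interprets a dtype argument of 0 as ScalarType 0, i.e. Byte (unsigned 8-bit). A quick libtorch check of that assumption (standalone sketch, not part of the test suite):

#include <torch/torch.h>
#include <iostream>

int main() {
  // aten::ones(%13, %19, ...) with %19 = 0 selects ScalarType 0, i.e. at::kByte,
  // so the intermediate %k_.1 in the test graph above is an 8-bit tensor.
  auto k = torch::ones({5, 5}, torch::TensorOptions().dtype(at::kByte));
  std::cout << k.scalar_type() << std::endl;                   // Byte
  std::cout << static_cast<int>(k.scalar_type()) << std::endl; // 0
  return 0;
}
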
