
Commit b7ca351

zjgarvey authored and sjarus committed
[TorchToLinalg] add support for quantized group conv (llvm#3341)
This addresses 7 of the model failures I'm seeing in the test suite. See [Shark-Turbine issue #566](nod-ai/SHARK-ModelDev#566). It needs the op `linalg.conv_2d_ngchw_gfchw_q` to be added upstream before merging; see [llvm-project PR #92136](llvm/llvm-project#92136). A small additional expansion of operand quantization is included in this patch to address a model failure that occurs when unblocking the quantized group convolutions in one of these ONNX models.
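The pattern these models hit is a fake-quantized int8 convolution with `groups > 1`, which previously failed to lower with "unimplemented: quantized grouped convolutions". A minimal PyTorch sketch of that pattern, mirroring the e2e test added below (scale and zero-point values here are illustrative, not taken from any of the failing models):

```python
import torch
import torch.nn.functional as F

class QuantizedGroupConv(torch.nn.Module):
    """Fake-quantize int8 operands, then run a grouped conv.

    This is the shape of computation the new lowering handles; the
    quantization parameters below are illustrative only.
    """

    def forward(self, x_int8, w_int8, bias):
        x = torch.dequantize(torch._make_per_tensor_quantized_tensor(x_int8, 0.01, 7))
        w = torch.dequantize(torch._make_per_tensor_quantized_tensor(w_int8, 0.01, 3))
        return F.conv2d(x, w, bias, stride=1, padding=0, dilation=1, groups=2)

m = QuantizedGroupConv()
x = torch.randint(-128, 128, (2, 8, 7, 8), dtype=torch.int8)   # N, Cin, H, W
w = torch.randint(-128, 128, (6, 4, 3, 2), dtype=torch.int8)   # Cout, Cin/groups, Hk, Wk
print(m(x, w, torch.rand(6)).shape)  # torch.Size([2, 6, 5, 7])
```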
1 parent 3ff2781 commit b7ca351

4 files changed (+44, -25 lines)


lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 20 additions & 15 deletions
@@ -829,7 +829,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
           op, "lhs and rhs of convolution must either be both int or fp");
     }
 
-    if (inputZp && weightZp && !isa<Torch::NoneType>(bias.getType())) {
+    if (inputZp && !isa<Torch::NoneType>(bias.getType())) {
       auto biasDTy = cast<RankedTensorType>(bias.getType()).getElementType();
       if (!biasDTy.isInteger(32)) {
         return rewriter.notifyMatchFailure(
@@ -1123,7 +1123,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     // - grouped 1d-3d
     // - grouped 1d-3d (quantized)
     // - ungrouped 1d-3d
-    if (groupSize == 1 && !inputZp && !weightZp) {
+    if (groupSize == 1 && !inputZp) {
       switch (numSpatialDims) {
       case 1:
         conv = rewriter
@@ -1164,7 +1164,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       return success();
     }
 
-    if (groupSize == 1 && inputZp && weightZp) {
+    if (groupSize == 1 && inputZp) {
       // The quantized version uses a different channel ordering so we need to
       // permute the tensors in order to use the existing path. We should
       // eventually directly support this channel ordering.
@@ -1224,10 +1224,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       return success();
     }
 
-    if (inputZp || weightZp)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: quantized grouped convolutions");
-
     if (numSpatialDims != 2)
       return rewriter.notifyMatchFailure(
           op, "unimplemented: only 2D grouped convolution supported");
@@ -1238,7 +1234,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     auto weightShape = makeShapeTorchCompatible(
         cast<RankedTensorType>(weight.getType()).getShape());
     if (weightShape[0] != kUnknownSize && inShape[1] == groupSize &&
-        weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) {
+        weightShape[0] % inShape[1] == 0 && weightShape[1] == 1 && !inputZp) {
       // Collapse weight shape
       SmallVector<ReassociationIndices, 4> collapsedDims = {{0, 1}, {2}, {3}};
       SmallVector<int64_t> collapsedShape{
@@ -1325,13 +1321,22 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     auto expandOutputTensor = expandGroups(outputTensor, 1);
 
     // TODO: add 1D and 3D case
-    conv = rewriter
-               .create<linalg::Conv2DNgchwGfchwOp>(
-                   loc, expandOutputTensor.getResultType(),
-                   ValueRange{paddedInputExpanded, weightExpanded},
-                   expandOutputTensor.getResult(), stridesAttr, dilationAttr)
-               .getResult(0);
-
+    if (!inputZp) {
+      conv = rewriter
+                 .create<linalg::Conv2DNgchwGfchwOp>(
+                     loc, expandOutputTensor.getResultType(),
+                     ValueRange{paddedInputExpanded, weightExpanded},
+                     expandOutputTensor.getResult(), stridesAttr, dilationAttr)
+                 .getResult(0);
+    } else {
+      conv = rewriter
+                 .create<linalg::Conv2DNgchwGfchwQOp>(
+                     loc, expandOutputTensor.getResultType(),
+                     ValueRange{paddedInputExpanded, weightExpanded, inputZp,
+                                weightZp},
+                     expandOutputTensor.getResult(), stridesAttr, dilationAttr)
+                 .getResult(0);
+    }
     conv = rewriter.create<tensor::CollapseShapeOp>(
         loc, outputTensor.getType(), conv,
         expandOutputTensor.getReassociationIndices());
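For reference, the only difference between `linalg.conv_2d_ngchw_gfchw` and the `_q` variant used above is that both operands are offset by their zero points before the multiply-accumulate, with i32 accumulation. A small Python sketch of my reading of that accumulation (not code from this patch; padding is omitted):

```python
import torch

def conv_2d_ngchw_gfchw_q_ref(x, w, x_zp, w_zp, stride=(1, 1), dilation=(1, 1)):
    """Zero-point-corrected grouped conv with int32 accumulation.

    x: (N, G, C, H, W) int8, w: (G, F, C, KH, KW) int8,
    returns (N, G, F, OH, OW) int32.
    """
    n, g, c, h, wd = x.shape
    _, f, _, kh, kw = w.shape
    sh, sw = stride
    dh, dw = dilation
    oh = (h - dh * (kh - 1) - 1) // sh + 1
    ow = (wd - dw * (kw - 1) - 1) // sw + 1
    xi = x.to(torch.int32) - x_zp   # subtract input zero point
    wi = w.to(torch.int32) - w_zp   # subtract weight zero point
    out = torch.zeros(n, g, f, oh, ow, dtype=torch.int32)
    for i in range(oh):
        for j in range(ow):
            # (N, G, C, KH, KW) window at this output position
            win = xi[:, :, :,
                     i * sh : i * sh + dh * (kh - 1) + 1 : dh,
                     j * sw : j * sw + dw * (kw - 1) + 1 : dw]
            # contract C, KH, KW per group: (N,G,1,C,KH,KW) * (1,G,F,C,KH,KW)
            out[:, :, :, i, j] = (win.unsqueeze(2) * wi.unsqueeze(0)).sum(dim=(-3, -2, -1))
    return out
```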

lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp

Lines changed: 1 addition & 1 deletion
@@ -378,7 +378,7 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
         QuantizeOperandsPastCommutingOps<AtenConvolutionOp, 5>,
         QuantizeOperandsPastCommutingOps<AtenReluOp, 0>,
         QuantizeOperandsPastCommutingOps<AtenMatmulOp, 2>,
-        QuantizeOperandsPastCommutingOps<AtenMmOp, 2>,
+        QuantizeOperandsPastCommutingOps<AtenMmOp, 4>,
         QuantizeAccumulator<AtenMmOp>, QuantizeAccumulator<AtenMatmulOp>,
         QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
         context);
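The only functional change here is bumping the bound for `AtenMmOp` from 2 to 4. Assuming the integer template parameter limits how many interposed "commuting" ops the rewrite will trace through when pulling quantization up to the mm operands (an assumption based on the template's name, not verified against its implementation), the hypothetical chain below illustrates the kind of pattern a deeper search can reach:

```python
import torch

def mm_behind_commuting_ops(a_int8: torch.Tensor, b_int8: torch.Tensor) -> torch.Tensor:
    # Hypothetical example: one dequantized operand passes through several
    # shape-only ops before reaching torch.mm. Scale/zero-point values are
    # illustrative.
    a = torch.dequantize(torch._make_per_tensor_quantized_tensor(a_int8, 0.1, 0))
    b = torch.dequantize(torch._make_per_tensor_quantized_tensor(b_int8, 0.1, 0))
    a = a.transpose(0, 1)      # commuting op 1
    a = a.unsqueeze(0)         # commuting op 2
    a = a.squeeze(0)           # commuting op 3
    a = a.contiguous()         # commuting op 4
    return torch.mm(a, b)
```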

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 7 additions & 0 deletions
@@ -277,6 +277,7 @@
     "QuantizedReluInt8_basic",
     "QuantizedReluUint8_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "ConvTranspose2DQInt8_basic",
     # Dynamo not supporting conv_tbc
     "ConvTbcModule_basic",
@@ -373,6 +374,7 @@
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
     "ConvTbcModule_basic",
     "ConvTranspose2DQInt8_basic",
@@ -543,6 +545,7 @@
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "ConvTbcModule_basic",
     "ConvTranspose2DQInt8_basic",
     "ConvolutionBackwardModule2DPadded_basic",
@@ -2147,6 +2150,7 @@
     "ElementwiseBitwiseAndScalarInt32Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "ConvTranspose2DQInt8_basic",
 }
 
@@ -2298,6 +2302,7 @@
     "Conv2dModule_basic",
     "Conv2dNoPaddingModule_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "Conv2dWithPaddingDilationStrideModule_basic",
     "Conv2dWithPaddingModule_basic",
     "Conv3dModule_basic",
@@ -2851,6 +2856,7 @@
     "ContainsIntList_True",
     "Conv1dModule_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
     "Conv3dModule_basic",
@@ -3637,6 +3643,7 @@
     "Conv2dModule_basic",
     "Conv2dNoPaddingModule_basic",
     "Conv2dQInt8Module_basic",
+    "Conv2dQInt8Module_grouped",
     "Conv2dWithPaddingDilationStrideModule_basic",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped",
     "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 16 additions & 9 deletions
@@ -1157,7 +1157,8 @@ def ConvTbcModule_basic(module, tu: TestUtils):
 
 
 class Conv2dQInt8Module(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, groups=1):
+        self.groups = groups
         super().__init__()
 
     @export
@@ -1186,7 +1187,7 @@ def forward(self, inputVec, weight, bias):
             stride=[1, 1],
             padding=[0, 0],
             dilation=[1, 1],
-            groups=1,
+            groups=self.groups,
         )
 
 
@@ -1198,13 +1199,12 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
     module.forward(inputVec, weight, bias)
 
 
-N = 10
-Cin = 5
-Cout = 7
-Hin = 10
-Win = 8
-Hker = 3
-Wker = 2
+@register_test_case(module_factory=lambda: Conv2dQInt8Module(groups=2))
+def Conv2dQInt8Module_grouped(module, tu: TestUtils):
+    inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8)
+    weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8)
+    bias = torch.rand(6)
+    module.forward(inputVec, weight, bias)
 
 
 class ConvTranspose2DQInt8Module(torch.nn.Module):
@@ -1244,6 +1244,13 @@ def forward(self, input, weight, bias):
 
 @register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
 def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
+    N = 10
+    Cin = 5
+    Cout = 7
+    Hin = 10
+    Win = 8
+    Hker = 3
+    Wker = 2
     module.forward(
         tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8),
         tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),
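The shapes in the new `Conv2dQInt8Module_grouped` case satisfy the usual grouped-conv constraints for `groups=2`; a quick sanity check of the arithmetic (plain Python, not part of the test suite):

```python
# input (2, 8, 7, 8), weight (6, 4, 3, 2), groups=2, no padding, unit stride/dilation
N, Cin, H, W = 2, 8, 7, 8
Cout, Cin_per_group, Hk, Wk = 6, 4, 3, 2
groups = 2

assert Cin_per_group * groups == Cin   # 4 * 2 == 8: weight matches input channels
assert Cout % groups == 0              # 6 % 2 == 0: output channels split evenly

Hout, Wout = H - Hk + 1, W - Wk + 1
print((N, Cout, Hout, Wout))           # (2, 6, 5, 7)
```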
