
Commit 596b58e

Authored by Ivan Garcia (ivangarcia44) and co-authored by Ivan Garcia
Add support for transposed grouped convolution in torch to linalg lowering (#4056)
The conversion of the convolution torch operation to linalg currently works for grouped convolution (number of groups > 1) and for transposed convolution, but it failed when both were used at the same time. This change set corrects that. The core of the changes is in Linear.cpp.

In transposed grouped convolution, the output filter dimension of the weights is the one divided by the number of groups, not the input channel dimension (see the "Variables" section in both links below for details). This was one of the fixes.
https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html

The other issue was that the weight expansion had to happen before the channel/filter dimension permutation/flip. The expansion deals with adjacent dimensions, but in the final weight tensor the group and input channel dimensions are no longer adjacent; once the dimensions are flipped, the expand_shape operation cannot generate the expected dimension format. See the comment in the code for details.

@rsuderman @vivekkhandelwal1 @zjgarvey @penguin-wwy @ubfx @sahas3 @dixinzhou @rafaelubalmw

---------

Co-authored-by: Ivan Garcia <[email protected]>
1 parent 96da98b commit 596b58e
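
The weight-layout fact behind the first fix can be checked directly against PyTorch. The sketch below is illustrative only (it is not part of the change) and reuses the shapes of the new e2e test: for a transposed convolution the weight is laid out as (C_in, C_out/groups, kH, kW), so it is the output filter dimension, not the input channel dimension, that is divided by the number of groups, and the grouped op is equivalent to running each input-channel group separately and concatenating the results.

```python
import torch
import torch.nn.functional as F

# Shapes mirror the new ConvolutionModule2DGroupedTranspose test:
# N=1, C_in=2, C_out=4, groups=2, 3x3 kernel.
x = torch.randn(1, 2, 5, 7)
w = torch.randn(2, 2, 3, 3)  # (C_in, C_out/groups, kH, kW)
b = torch.randn(4)

ref = F.conv_transpose2d(x, w, b, stride=2, padding=1, groups=2)

# Equivalent manual grouping: split the input-channel dim of both the
# input and the weight, run each group independently, and concatenate
# the per-group outputs along the channel dim.
outs = [
    F.conv_transpose2d(x[:, g : g + 1], w[g : g + 1], stride=2, padding=1)
    for g in range(2)
]
manual = torch.cat(outs, dim=1) + b.view(1, -1, 1, 1)

assert torch.allclose(ref, manual, atol=1e-5)
```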

File tree

4 files changed: +185 −31 lines changed

- lib/Conversion/TorchToLinalg/Linear.cpp
- projects/pt1/e2e_testing/xfail_sets.py
- projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
- test/Conversion/TorchToLinalg/convolution.mlir

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 77 additions & 31 deletions
@@ -955,7 +955,55 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       if (isa<mlir::IntegerType>(inputDTy))
         pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
     }
+
+    // The expandWeight lambda function below is used to expand the group
+    // dimension. For the normal case the group dimension is expanded out
+    // of the output filter dimension:
+    //   expand F,C,H,W -> G,F/G,C,H,W
+    //
+    // Note that the group dimension has to be the first dimension. For the
+    // transposed convolution case, the group convolution is extracted out
+    // of the input channel dimension. But note that the input channel
+    // dimension is interchanged with the output filter dimension (due to
+    // the transposed operation). Because of this the group and input
+    // channel dimensions will not be adjacent and the expand_shape
+    // operation will not work.
+    //
+    // For this reason, in the transposed convolution case the expandWeight
+    // lambda needs to be executed before this dimension flipping by doing
+    // these two steps:
+    //
+    //   Expansion: C,F,H,W -> G,C/G,F,H,W
+    //
+    //   Dimension interchange: G,C/G,F,H,W -> G,F,C/G,H,W
+    //
+    auto expandWeight = [&](Value tensor) {
+      auto inType = cast<RankedTensorType>(tensor.getType());
+      auto inShape = makeShapeTorchCompatible(inType.getShape());
+
+      SmallVector<int64_t> outShape{numGroups,
+                                    (inShape[0] == kUnknownSize
+                                         ? kUnknownSize
+                                         : (inShape[0] / numGroups)),
+                                    inShape[1]};
+      outShape.append(inShape.begin() + 2, inShape.end());
+
+      SmallVector<ReassociationIndices> indices{};
+      int currIndex = 0;
+      indices.push_back({0, 1});
+      currIndex += 2;
+      for (int i = currIndex; i <= (long)inShape.size(); i++)
+        indices.push_back({i});
+
+      auto retType = inType.clone(makeShapeLLVMCompatible(outShape));
+      return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor,
+                                                    indices);
+    };
+
     if (transposed) {
+      bool isGroupedConv = numGroups > 1;
+      weight = isGroupedConv ? expandWeight(weight) : weight;
+
       Value c0 =
           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
       Value c1 =
@@ -965,25 +1013,40 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {

     // Transpose and flip weight
     SmallVector<Value> weightInitDims = getTensorSizes(rewriter, loc, weight);
-    std::iter_swap(weightInitDims.begin(), weightInitDims.begin() + 1);
-    outDims[1] = weightInitDims[0];
+    if (isGroupedConv) {
+      // We need to skip the first dimension (group) in this case, also the
+      // output dimension needs to consider the number of groups.
+      std::iter_swap(weightInitDims.begin() + 1, weightInitDims.begin() + 2);
+      auto numGroupsVal =
+          rewriter.create<mlir::arith::ConstantIndexOp>(loc, numGroups);
+      outDims[1] = rewriter.createOrFold<mlir::arith::MulIOp>(
+          loc, weightInitDims[1], numGroupsVal);
+    } else {
+      std::iter_swap(weightInitDims.begin(), weightInitDims.begin() + 1);
+      outDims[1] = weightInitDims[0];
+    }
+    auto weightRank = weightInitDims.size();
     Value weightInitTensor =
         createZeroInitTensor(rewriter, loc, weightInitDims, weightDTy);
     SmallVector<utils::IteratorType> iteratorTypes(
-        inRank, utils::IteratorType::parallel);
+        weightRank, utils::IteratorType::parallel);
     SmallVector<AffineMap> indexingMaps{
-        AffineMap::getMultiDimIdentityMap(inRank, context)};
+        AffineMap::getMultiDimIdentityMap(weightRank, context)};
     weight = rewriter
                  .create<linalg::GenericOp>(
                      loc, weightInitTensor.getType(), ValueRange{},
                      weightInitTensor, indexingMaps, iteratorTypes,
                      [&](OpBuilder &b, Location loc, ValueRange args) {
                        SmallVector<Value> indices;
-                       for (size_t i = 0; i < inRank; i++)
+                       for (size_t i = 0; i < weightRank; i++)
                          indices.push_back(b.create<linalg::IndexOp>(loc, i));
-                       std::iter_swap(indices.begin(), indices.begin() + 1);
-                       // Flip only the spatial dimensions (from 2 to inRank)
-                       for (size_t flipDim = 2; flipDim < inRank; flipDim++) {
+                       auto fcIdxSwapOffset = isGroupedConv ? 1 : 0;
+                       std::iter_swap(indices.begin() + fcIdxSwapOffset,
+                                      indices.begin() + fcIdxSwapOffset + 1);
+                       // Flip only the spatial dimensions (from 2 to
+                       // weightRank)
+                       for (size_t flipDim = fcIdxSwapOffset + 2;
+                            flipDim < weightRank; flipDim++) {
                          indices[flipDim] = b.create<arith::SubIOp>(
                              loc,
                              b.create<arith::SubIOp>(
@@ -1373,43 +1436,26 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
                                                     indices);
     };

-    // expand F,C,H,W -> G,F/G,C,H,W
-    auto expandWeight = [&](Value tensor) {
-      auto inType = cast<RankedTensorType>(tensor.getType());
-      auto inShape = makeShapeTorchCompatible(inType.getShape());
-
-      SmallVector<int64_t> outShape{
-          numGroups,
-          (inShape[0] == kUnknownSize ? kUnknownSize : inShape[0] / numGroups)};
-      outShape.append(inShape.begin() + 1, inShape.end());
-
-      SmallVector<ReassociationIndices> indices{{0, 1}};
-      for (auto i = 2; i <= (long)inShape.size(); i++)
-        indices.push_back({i});
-
-      auto retType = inType.clone(makeShapeLLVMCompatible(outShape));
-      return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor,
-                                                    indices);
-    };
-
     Value paddedInputExpanded = expandGroups(paddedInput, 1);
-    Value weightExpanded = expandWeight(weight);
+    // If we have a transposed convolution, this needs to be handled before
+    // dimension permutation. See comments in the expandWeight lambda
+    // definition for details.
+    weight = transposed ? weight : expandWeight(weight);
     auto expandOutputTensor = expandGroups(outputTensor, 1);

     // TODO: add 1D and 3D case
     if (!inputZp) {
       conv = rewriter
                  .create<linalg::Conv2DNgchwGfchwOp>(
                      loc, expandOutputTensor.getResultType(),
-                     ValueRange{paddedInputExpanded, weightExpanded},
+                     ValueRange{paddedInputExpanded, weight},
                      expandOutputTensor.getResult(), stridesAttr, dilationAttr)
                  .getResult(0);
     } else {
       conv = rewriter
                  .create<linalg::Conv2DNgchwGfchwQOp>(
                      loc, expandOutputTensor.getResultType(),
-                     ValueRange{paddedInputExpanded, weightExpanded, inputZp,
-                                weightZp},
+                     ValueRange{paddedInputExpanded, weight, inputZp, weightZp},
                      expandOutputTensor.getResult(), stridesAttr, dilationAttr)
                  .getResult(0);
     }
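
To see concretely why the expansion must precede the filter/channel interchange, here is a minimal PyTorch sketch with made-up shapes (G=2, C_in=4, C_out=6; none of these names come from the patch). A single reshape can only split a dimension into adjacent size factors, so the (G, C/G) split has to happen while the input channel is still the leading dimension; the interchange afterwards produces the G,F,C/G,H,W filter layout that linalg.conv_2d_ngchw_gfchw consumes.

```python
import torch

G, C_in, C_out, kH, kW = 2, 4, 6, 3, 3
# Transposed-conv weight layout (see the ConvTranspose2d docs linked in the
# commit message): (C_in, C_out/G, kH, kW), i.e. C,F,H,W.
w = torch.randn(C_in, C_out // G, kH, kW)

# Step 1, expansion while C_in is still leading: C,F,H,W -> G,C/G,F,H,W.
expanded = w.reshape(G, C_in // G, C_out // G, kH, kW)

# Step 2, dimension interchange: G,C/G,F,H,W -> G,F,C/G,H,W.
grouped = expanded.transpose(1, 2)
assert grouped.shape == (G, C_out // G, C_in // G, kH, kW)

# Doing the interchange first yields F,C,H,W, where the two factors of the
# group split (G and C_in/G) are no longer adjacent dimensions, so no single
# reshape/expand_shape of this tensor can produce the G,F,C/G,H,W layout.
swapped_first = w.transpose(0, 1)  # (C_out/G, C_in, kH, kW)
```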

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 2 additions & 0 deletions
@@ -3537,6 +3537,7 @@
     "ConvolutionModule2DTransposeStridedStatic_basic",
     "ConvolutionModule2DTransposeStrided_basic",
     "ConvolutionModule2DTranspose_basic",
+    "ConvolutionModule2DGroupedTranspose_basic",
     "CumsumInputDtypeInt32Module_basic",
     "CumsumModule_basic",
     "CumsumStaticModule_basic",
@@ -4113,6 +4114,7 @@
     "ConvolutionModule2DTransposeStridedStatic_basic",
     "ConvolutionModule2DTransposeStrided_basic",
     "ConvolutionModule2DTranspose_basic",
+    "ConvolutionModule2DGroupedTranspose_basic",
     "CopyModule_basic",
     "CopyWithDifferentDTypesAndSizesModule_basic",
     "CopyWithDifferentDTypesModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 32 additions & 0 deletions
@@ -1725,3 +1725,35 @@ def DeformConv2D_basic(module, tu: TestUtils):
     offset = tu.rand(N, offset_dim1, Hout, Wout)
     weight = tu.rand(Cout, Cin, Hker, Wker)
     module.forward(input, offset, weight)
+
+
+class ConvolutionModule2DGroupedTranspose(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 2, 5, 7], torch.float32, True),
+            ([2, 2, 3, 3], torch.float32, True),
+            ([4], torch.float32, True),
+        ]
+    )
+    def forward(self, inputVec, weight, bias):
+        return torch.ops.aten.convolution(
+            inputVec,
+            weight,
+            bias=bias,
+            stride=[2, 2],
+            padding=[1, 1],
+            dilation=[1, 1],
+            transposed=True,
+            output_padding=[0, 0],
+            groups=2,
+        )
+
+
+@register_test_case(module_factory=lambda: ConvolutionModule2DGroupedTranspose())
+def ConvolutionModule2DGroupedTranspose_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 5, 7), tu.rand(2, 2, 3, 3), tu.rand(4))
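
As a sanity check on this configuration, the expected output spatial size follows the transposed-convolution shape formula from the ConvTranspose2d documentation linked in the commit message; the snippet below is illustrative and not part of the test suite.

```python
# H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel - 1)
#         + output_padding + 1   (per the ConvTranspose2d docs)
def conv_transpose_out(size, stride=2, padding=1, dilation=1, kernel=3, output_padding=0):
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

print(conv_transpose_out(5), conv_transpose_out(7))  # 9 13: output is 1x4x9x13
```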

test/Conversion/TorchToLinalg/convolution.mlir

Lines changed: 74 additions & 0 deletions
@@ -76,3 +76,77 @@ func.func @conv_broadcast(%arg0: !torch.vtensor<[1,80,3000],f32>, %arg1: !torch.
   %2 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %0, %0, %false, %1, %int1 : !torch.vtensor<[1,80,3000],f32>, !torch.vtensor<[1024,80,3],f32>, !torch.vtensor<[1024],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1024,3000],f32>
   return %2 : !torch.vtensor<[1,1024,3000],f32>
 }
+
+// CHECK-LABEL: func.func @transposedConv2D(
+// CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32>
+// CHECK: = linalg.generic
+// CHECK-SAME: outs(%[[BROADCASTED_WEIGHTS_INIT:.*]] : tensor<4x2x3x3xf32>) {
+// CHECK: %[[WEIGHTS:.*]] = tensor.extract
+// CHECK-SAME: : tensor<2x4x3x3xf32>
+// CHECK-NEXT: linalg.yield %[[BROADCASTED_WEIGHTS:.*]] : f32
+// CHECK-NEXT: } -> tensor<4x2x3x3xf32>
+// CHECK: %[[BROADCASTED_BIAS:.*]] = linalg.broadcast ins(%[[BIAS:.*]] : tensor<4xf32>) outs(%[[BROADCASTED_BIAS_INIT:.*]] : tensor<1x4x11x15xf32>) dimensions = [0, 2, 3]
+// CHECK: %[[CONV_RESULT:.*]] = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[INPUT_TENSOR_ADAPTED:.*]], %[[BROADCASTED_WEIGHTS:.*]] : tensor<1x2x13x17xf32>, tensor<4x2x3x3xf32>) outs(%[[BROADCASTED_BIAS:.*]] : tensor<1x4x11x15xf32>) -> tensor<1x4x11x15xf32>
+// CHECK-NEXT: %[[OUTPUT_TENSOR_DYN:.*]] = tensor.cast %[[CONV_RESULT:.*]] : tensor<1x4x11x15xf32> to tensor<1x4x?x?xf32>
+// CHECK-NEXT: %[[OUTPUT_TENSOR:.*]] = tensor.cast %[[OUTPUT_TENSOR_DYN:.*]] : tensor<1x4x?x?xf32> to tensor<1x4x10x14xf32>
+func.func @transposedConv2D(%arg0: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32> attributes {torch.assume_strict_symbolic_shapes} {
+  %int0 = torch.constant.int 0
+  %true = torch.constant.bool true
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %0 = torch.vtensor.literal(dense_resource<torch_tensor_2_4_3_3_torch.float32> : tensor<2x4x3x3xf32>) : !torch.vtensor<[2,4,3,3],f32>
+  %1 = torch.vtensor.literal(dense_resource<torch_tensor_4_torch.float32> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %true, %5, %int1 : !torch.vtensor<[1,2,5,7],f32>, !torch.vtensor<[2,4,3,3],f32>, !torch.vtensor<[4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,10,14],f32>
+  return %6 : !torch.vtensor<[1,4,10,14],f32>
+}
+
+// CHECK-LABEL: func.func @groupedConvolution2D(
+// CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[1,4,5,7],f32>) -> !torch.vtensor<[1,4,5,7],f32>
+// CHECK: %[[BROADCASTED_BIAS:.*]] = linalg.broadcast ins(%[[BIAS:.*]] : tensor<4xf32>) outs(%[[BROADCASTED_BIAS_INIT:.*]] : tensor<1x4x5x7xf32>) dimensions = [0, 2, 3]
+// CHECK: %[[CONV_RESULT:.*]] = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[INPUT_TENSOR_ADAPTED:.*]], %[[BROADCASTED_WEIGHTS:.*]] : tensor<1x2x2x7x9xf32>, tensor<2x2x2x3x3xf32>) outs(%[[BROADCASTED_BIAS:.*]] : tensor<1x2x2x5x7xf32>) -> tensor<1x2x2x5x7xf32>
+// CHECK-NEXT: %[[OUTPUT_TENSOR:.*]] = tensor.collapse_shape
+// CHECK-SAME: tensor<1x2x2x5x7xf32> into tensor<1x4x5x7xf32>
+func.func @groupedConvolution2D(%arg0: !torch.vtensor<[1,4,5,7],f32>) -> !torch.vtensor<[1,4,5,7],f32> attributes {torch.assume_strict_symbolic_shapes} {
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %0 = torch.vtensor.literal(dense_resource<torch_tensor_4_2_3_3_torch.float32> : tensor<4x2x3x3xf32>) : !torch.vtensor<[4,2,3,3],f32>
+  %1 = torch.vtensor.literal(dense_resource<torch_tensor_4_torch.float32> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %false, %5, %int2 : !torch.vtensor<[1,4,5,7],f32>, !torch.vtensor<[4,2,3,3],f32>, !torch.vtensor<[4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,5,7],f32>
+  return %6 : !torch.vtensor<[1,4,5,7],f32>
+}
+
+// CHECK-LABEL: func.func @transposedGroupedConvolution2D(
+// CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32>
+// CHECK: %[[BROADCASTED_BIAS:.*]] = linalg.broadcast ins(%[[BIAS:.*]] : tensor<4xf32>) outs(%[[BROADCASTED_BIAS_INIT:.*]] : tensor<1x4x11x15xf32>) dimensions = [0, 2, 3]
+// CHECK: %[[CONV_RESULT:.*]] = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[INPUT_TENSOR_ADAPTED:.*]], %[[BROADCASTED_WEIGHTS:.*]] : tensor<1x2x1x13x17xf32>, tensor<2x2x1x3x3xf32>) outs(%[[BROADCASTED_BIAS:.*]] : tensor<1x2x2x11x15xf32>) -> tensor<1x2x2x11x15xf32>
+// CHECK-NEXT: %[[COLLAPSED_TENSOR:.*]] = tensor.collapse_shape
+// CHECK-SAME: tensor<1x2x2x11x15xf32> into tensor<1x4x11x15xf32>
+// CHECK-NEXT: %[[OUTPUT_TENSOR_DYN:.*]] = tensor.cast %[[COLLAPSED_TENSOR:.*]] : tensor<1x4x11x15xf32> to tensor<1x4x?x?xf32>
+// CHECK-NEXT: %[[OUTPUT_TENSOR:.*]] = tensor.cast %[[OUTPUT_TENSOR_DYN:.*]] : tensor<1x4x?x?xf32> to tensor<1x4x10x14xf32>
+func.func @transposedGroupedConvolution2D(%arg0: !torch.vtensor<[1,2,5,7],f32>) -> !torch.vtensor<[1,4,10,14],f32> attributes {torch.assume_strict_symbolic_shapes} {
+  %int0 = torch.constant.int 0
+  %true = torch.constant.bool true
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %0 = torch.vtensor.literal(dense_resource<torch_tensor_2_2_3_3_torch.float32> : tensor<2x2x3x3xf32>) : !torch.vtensor<[2,2,3,3],f32>
+  %1 = torch.vtensor.literal(dense_resource<torch_tensor_4_torch.float32> : tensor<4xf32>) : !torch.vtensor<[4],f32>
+  %2 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %6 = torch.aten.convolution %arg0, %0, %1, %2, %3, %4, %true, %5, %int2 : !torch.vtensor<[1,2,5,7],f32>, !torch.vtensor<[2,2,3,3],f32>, !torch.vtensor<[4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,10,14],f32>
+  return %6 : !torch.vtensor<[1,4,10,14],f32>
+}
