Skip to content

Commit 8e0eafd

Browse files
[MLIR][TORCH] Add support for 1-d group convolution (#3904)
This commit adds support for 1-d group convolution by transforming it into a 2-d group convolution, which is already supported. This commit also refactors the unsqueeze and squeeze tensor utilities. --------- Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent 2c72a82 commit 8e0eafd

File tree

6 files changed

+230
-93
lines changed

6 files changed

+230
-93
lines changed

include/torch-mlir/Conversion/Utils/Utils.h

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,15 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
9797
Value torchOptionalInt, Value builtinInt,
9898
Value defaultValue, Value dimSize);
9999

100+
// Helper function to unsqueeze the input tensor at given dim.
101+
// Returns the unsqueezed tensor or failure.
102+
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
103+
Value input, int64_t dim);
104+
105+
// Helper function to squeeze the input tensor at given dim.
106+
// Returns the squeezed tensor or failure.
107+
FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
108+
Value input, int64_t dim);
100109
} // namespace Torch
101110
} // namespace torch
102111
} // namespace mlir

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 13 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -1642,69 +1642,18 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {
16421642
ConversionPatternRewriter &rewriter) const override {
16431643
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
16441644
return failure();
1645-
Value input = adaptor.getSelf();
1646-
auto inputType = cast<RankedTensorType>(input.getType());
1647-
int64_t inputRank = inputType.getRank();
1648-
1649-
if (inputRank == 0) {
1650-
return rewriter.notifyMatchFailure(
1651-
op, "zero input rank should have been handled by the folder");
1652-
}
1653-
16541645
int64_t dim;
16551646
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
16561647
return rewriter.notifyMatchFailure(op, "dim must be constant");
1657-
dim = toPositiveDim(dim, inputRank);
1658-
if (!isValidDim(dim, inputRank))
1659-
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
1660-
1661-
// assert dynamic squeeze dim size == 1
1662-
if (inputType.isDynamicDim(dim)) {
1663-
Value cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
1664-
Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
1665-
Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
1666-
Value cmp = rewriter.create<arith::CmpIOp>(
1667-
op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
1668-
rewriter.create<cf::AssertOp>(
1669-
op.getLoc(), cmp,
1670-
rewriter.getStringAttr(
1671-
"Expected dynamic squeeze dim size to be statically 1"));
1672-
}
1673-
1674-
const TypeConverter *typeConverter = getTypeConverter();
1675-
auto resultType =
1676-
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
1677-
int64_t resultRank = resultType.getRank();
16781648

1679-
// If the dim(th) dimension of operand tensor type is not statically unit,
1680-
// `aten.squeeze` will behave as an identity operation.
1681-
if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
1682-
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
1683-
return success();
1649+
auto squeezeTensorInfo =
1650+
squeezeTensor(rewriter, op, adaptor.getSelf(), dim);
1651+
if (failed(squeezeTensorInfo)) {
1652+
return rewriter.notifyMatchFailure(op,
1653+
"cannot generate squeeze tensor");
16841654
}
16851655

1686-
SmallVector<ReassociationIndices> reassociationMap(resultRank);
1687-
bool alreadyCrossedSqueezedDim = false;
1688-
for (int i = 0; i != resultRank; i++) {
1689-
if (alreadyCrossedSqueezedDim) {
1690-
reassociationMap[i].push_back(i + 1);
1691-
} else {
1692-
reassociationMap[i].push_back(i);
1693-
if (dim != 0 && i != dim - 1)
1694-
continue;
1695-
1696-
alreadyCrossedSqueezedDim = true;
1697-
if (dim == 0)
1698-
reassociationMap[0].push_back(1);
1699-
if (i == dim - 1)
1700-
reassociationMap[i].push_back(dim);
1701-
}
1702-
}
1703-
// Note: In case the operand tensor type is of unit rank and is statically
1704-
// shaped with unit dimension, the `reassociationMap` will be empty and the
1705-
// input will be collapsed to a 0-D tensor.
1706-
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(op, resultType, input,
1707-
reassociationMap);
1656+
rewriter.replaceOp(op, squeezeTensorInfo.value());
17081657
return success();
17091658
}
17101659
};
@@ -1722,36 +1671,15 @@ class ConvertAtenUnsqueezeOp : public OpConversionPattern<AtenUnsqueezeOp> {
17221671
int64_t dim;
17231672
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
17241673
return rewriter.notifyMatchFailure(op, "dim must be constant");
1725-
auto inputRank =
1726-
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
1727-
dim = toPositiveDim(dim, inputRank + 1);
1728-
if (!isValidDim(dim, inputRank + 1))
1729-
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
17301674

1731-
SmallVector<ReassociationIndices> reassociationMap(inputRank);
1732-
// From the perspective of the reassociation map, the situation of
1733-
// unsqueezing before or after the last dimension is symmetrical.
1734-
// Normalize it to the "before" case.
1735-
// The 0 case is special here, since there is no last dimension to insert
1736-
// before -- we simply rely on the loop below iterating 0 times.
1737-
if (dim == inputRank && inputRank != 0)
1738-
dim = inputRank - 1;
1739-
bool alreadyCrossedExpandedDim = false;
1740-
for (int i = 0; i != inputRank; i++) {
1741-
if (alreadyCrossedExpandedDim) {
1742-
reassociationMap[i].push_back(i + 1);
1743-
} else {
1744-
reassociationMap[i].push_back(i);
1745-
if (i == dim) {
1746-
reassociationMap[i].push_back(i + 1);
1747-
alreadyCrossedExpandedDim = true;
1748-
}
1749-
}
1675+
auto unsqueezeTensorInfo =
1676+
unsqueezeTensor(rewriter, op, adaptor.getSelf(), dim);
1677+
if (failed(unsqueezeTensorInfo)) {
1678+
return rewriter.notifyMatchFailure(op,
1679+
"cannot generate unsqueeze tensor");
17501680
}
1751-
auto resultType = cast<RankedTensorType>(
1752-
getTypeConverter()->convertType(op->getResult(0).getType()));
1753-
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1754-
op, resultType, adaptor.getSelf(), reassociationMap);
1681+
1682+
rewriter.replaceOp(op, unsqueezeTensorInfo.value());
17551683
return success();
17561684
}
17571685
};

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 64 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -850,6 +850,48 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
850850
return rewriter.notifyMatchFailure(op,
851851
"only support constant int dilations");
852852

853+
// Checks for valid group size
854+
int64_t numGroups;
855+
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups)))
856+
return rewriter.notifyMatchFailure(op,
857+
"only constant group size supported.");
858+
Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups());
859+
860+
// Adding support for 1d group convolution by converting the 1d-conv to
861+
// 2d-conv.
862+
// TODO: Replace this logic with the appropriate linalg op for 1-d group
863+
// convolution once that support is added.
864+
bool is1DGroupConv = (numSpatialDims == 1 && numGroups != 1);
865+
if (is1DGroupConv) {
866+
// Unsqueezing the last dim of input and weight. Also extending the
867+
// dilation, stride, padding, and output padding lists.
868+
auto unsqueezeInputInfo =
869+
unsqueezeTensor(rewriter, op, input, /*dim=*/-1);
870+
if (failed(unsqueezeInputInfo)) {
871+
return rewriter.notifyMatchFailure(op,
872+
"cannot generate unsqueeze tensor");
873+
}
874+
input = unsqueezeInputInfo.value();
875+
876+
auto unsqueezeWeightInfo =
877+
unsqueezeTensor(rewriter, op, weight, /*dim=*/-1);
878+
if (failed(unsqueezeWeightInfo)) {
879+
return rewriter.notifyMatchFailure(op,
880+
"cannot generate unsqueeze tensor");
881+
}
882+
weight = unsqueezeWeightInfo.value();
883+
884+
Value cstZero = rewriter.create<arith::ConstantOp>(
885+
loc, rewriter.getI64IntegerAttr(0));
886+
paddingIntValues.push_back(cstZero);
887+
outputPaddingIntValues.push_back(cstZero);
888+
strideInts.push_back(1);
889+
dilationInts.push_back(1);
890+
891+
inRank++;
892+
numSpatialDims++;
893+
}
894+
853895
Value inBatch = getDimOp(rewriter, loc, input, 0);
854896
Value inChannels = getDimOp(rewriter, loc, input, 1);
855897
SmallVector<Value> inDims;
@@ -861,13 +903,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
861903
for (size_t i = 2; i < inRank; i++)
862904
weightDims.push_back(getDimOp(rewriter, loc, weight, i));
863905

864-
// Checks for valid group size
865-
int64_t numGroups;
866-
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups)))
867-
return rewriter.notifyMatchFailure(op,
868-
"only constant group size supported.");
869-
Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups());
870-
871906
auto validate = [&](Value toValidate, std::string err) {
872907
Value c0 =
873908
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
@@ -1280,13 +1315,24 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
12801315
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
12811316
resultElementType);
12821317
}
1318+
1319+
if (is1DGroupConv) {
1320+
// Squeezing the last dim of the result of conv.
1321+
auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1);
1322+
if (failed(squeezeOutputInfo)) {
1323+
return rewriter.notifyMatchFailure(op,
1324+
"cannot generate squeeze tensor");
1325+
}
1326+
conv = squeezeOutputInfo.value();
1327+
}
1328+
12831329
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
12841330
return success();
12851331
}
12861332

12871333
if (numSpatialDims != 2)
12881334
return rewriter.notifyMatchFailure(
1289-
op, "unimplemented: only 2D grouped convolution supported");
1335+
op, "unimplemented: only 1D and 2D grouped convolution supported");
12901336

12911337
// Grouped case, use the grouped conv linalg op
12921338
auto expandGroups = [&](Value tensor, size_t dim) {
@@ -1371,6 +1417,16 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
13711417
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
13721418
resultElementType);
13731419
}
1420+
1421+
if (is1DGroupConv) {
1422+
// Squeezing the last dim of the result of conv.
1423+
auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1);
1424+
if (failed(squeezeOutputInfo)) {
1425+
return rewriter.notifyMatchFailure(op,
1426+
"cannot generate squeeze tensor");
1427+
}
1428+
conv = squeezeOutputInfo.value();
1429+
}
13741430
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
13751431
return success();
13761432
}

lib/Conversion/Utils/Utils.cpp

Lines changed: 113 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -447,6 +447,119 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
447447
return castIntToIndex(rewriter, loc, boundedByDimSize);
448448
}
449449

450+
// Helper function to unsqueeze the input tensor at given dim.
451+
// Returns the unsqueezed tensor or failure.
452+
FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
453+
Value input, int64_t dim) {
454+
auto inputType = cast<RankedTensorType>(input.getType());
455+
int64_t inputRank = inputType.getRank();
456+
ArrayRef<int64_t> inputShape = inputType.getShape();
457+
458+
// `input` has a reduced rank. Hence add 1.
459+
int64_t unsqueezedRank = inputShape.size() + 1;
460+
dim = toPositiveDim(dim, unsqueezedRank);
461+
if (!isValidDim(dim, unsqueezedRank)) {
462+
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
463+
}
464+
465+
SmallVector<int64_t> unsqueezedShape{inputShape};
466+
unsqueezedShape.insert(unsqueezedShape.begin() + dim, 1);
467+
Type unsqueezedType =
468+
RankedTensorType::get(unsqueezedShape, inputType.getElementType());
469+
470+
SmallVector<ReassociationIndices> reassociationMap(inputRank);
471+
// From the perspective of the reassociation map, the situation of
472+
// unsqueezing before or after the last dimension is symmetrical.
473+
// Normalize it to the "before" case.
474+
// The 0 case is special here, since there is no last dimension to insert
475+
// before -- we simply rely on the loop below iterating 0 times.
476+
if (dim == inputRank && inputRank != 0)
477+
dim = inputRank - 1;
478+
bool alreadyCrossedExpandedDim = false;
479+
for (int i = 0; i != inputRank; i++) {
480+
if (alreadyCrossedExpandedDim) {
481+
reassociationMap[i].push_back(i + 1);
482+
} else {
483+
reassociationMap[i].push_back(i);
484+
if (i == dim) {
485+
reassociationMap[i].push_back(i + 1);
486+
alreadyCrossedExpandedDim = true;
487+
}
488+
}
489+
}
490+
Value unsqueezed = rewriter.create<tensor::ExpandShapeOp>(
491+
op->getLoc(), unsqueezedType, input, reassociationMap);
492+
return unsqueezed;
493+
}
494+
495+
// Helper function to squeeze the input tensor at given dim.
496+
// Returns the squeezed tensor or failure.
497+
FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
498+
Value input, int64_t dim) {
499+
Location loc = op->getLoc();
500+
auto inputType = cast<RankedTensorType>(input.getType());
501+
int64_t inputRank = inputType.getRank();
502+
503+
// No scope for squeezing the input.
504+
if (inputRank == 0)
505+
return input;
506+
507+
dim = toPositiveDim(dim, inputRank);
508+
if (!isValidDim(dim, inputRank))
509+
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
510+
511+
// assert dynamic squeeze dim size == 1
512+
if (inputType.isDynamicDim(dim)) {
513+
Value cstDim = rewriter.create<arith::ConstantIndexOp>(loc, dim);
514+
Value dimVal = rewriter.create<tensor::DimOp>(loc, input, cstDim);
515+
Value cstOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
516+
Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
517+
dimVal, cstOne);
518+
rewriter.create<cf::AssertOp>(
519+
loc, cmp,
520+
rewriter.getStringAttr(
521+
"Expected dynamic squeeze dim size to be statically 1"));
522+
}
523+
524+
ArrayRef<int64_t> inputShape = inputType.getShape();
525+
SmallVector<int64_t> squeezedShape;
526+
squeezedShape.append(inputShape.begin(), inputShape.begin() + dim);
527+
squeezedShape.append(inputShape.begin() + dim + 1, inputShape.end());
528+
int64_t squeezedRank = inputRank - 1;
529+
Type squeezedType =
530+
RankedTensorType::get(squeezedShape, inputType.getElementType());
531+
532+
// If the dim(th) dimension of operand tensor type is not statically unit,
533+
// squeeze will behave as an identity operation.
534+
if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
535+
return input;
536+
}
537+
538+
SmallVector<ReassociationIndices> reassociationMap(squeezedRank);
539+
bool alreadyCrossedSqueezedDim = false;
540+
for (int i = 0; i != squeezedRank; i++) {
541+
if (alreadyCrossedSqueezedDim) {
542+
reassociationMap[i].push_back(i + 1);
543+
} else {
544+
reassociationMap[i].push_back(i);
545+
if (dim != 0 && i != dim - 1)
546+
continue;
547+
548+
alreadyCrossedSqueezedDim = true;
549+
if (dim == 0)
550+
reassociationMap[0].push_back(1);
551+
if (i == dim - 1)
552+
reassociationMap[i].push_back(dim);
553+
}
554+
}
555+
// Note: In case the operand tensor type is of unit rank and is statically
556+
// shaped with unit dimension, the `reassociationMap` will be empty and the
557+
// input will be collapsed to a 0-D tensor.
558+
Value squeezed = rewriter.create<tensor::CollapseShapeOp>(
559+
op->getLoc(), squeezedType, input, reassociationMap);
560+
return squeezed;
561+
}
562+
450563
} // namespace Torch
451564
} // namespace torch
452565
} // namespace mlir

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2731,6 +2731,7 @@
27312731
"ElementwiseBitwiseAndScalarInt64Module_basic",
27322732
"ElementwiseBitwiseAndScalarInt32Module_basic",
27332733
"ElementwiseBitwiseAndScalarInt8Module_basic",
2734+
"Conv1dGroupModule_basic",
27342735
"Conv2dQInt8Module_basic",
27352736
"Conv2dQInt8Module_depthwise",
27362737
"Conv2dQInt8Module_grouped",
@@ -2886,6 +2887,7 @@
28862887
"Conv1dModule_basic",
28872888
"Conv1dWithSamePaddingModule_basic",
28882889
"Conv1dWithValidPaddingModule_basic",
2890+
"Conv1dGroupModule_basic",
28892891
"Conv2dBiasNoPaddingModule_basic",
28902892
"Conv2dModule_basic",
28912893
"Conv2dNoPaddingModule_basic",
@@ -3593,6 +3595,7 @@
35933595
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
35943596
"Conv1dWithSamePaddingModule_basic",
35953597
"Conv1dWithValidPaddingModule_basic",
3598+
"Conv1dGroupModule_basic",
35963599
"Conv2dQInt8Module_basic",
35973600
"Conv2dQInt8Module_depthwise",
35983601
"Conv2dQInt8Module_grouped",
@@ -4186,6 +4189,7 @@
41864189
"Conv1dWithSamePaddingModule_basic",
41874190
"Conv1dWithValidPaddingModule_basic",
41884191
"Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic",
4192+
"Conv1dGroupModule_basic",
41894193
"Conv2dBiasNoPaddingModule_basic",
41904194
"Conv2dModule_basic",
41914195
"Conv2dNoPaddingModule_basic",

0 commit comments

Comments
 (0)