Commit e81282a
Support for prims collapse op (lowering to linalg) (#2572)
Steps taken:
1) Add generator code to torch_ods_gen.py and run update_torch_ods.sh.
2) Add (custom) shape and dtype inference generator code to abstract_interp_lib_gen.py and run update_abstract_interp_lib.sh.
3) Implement the lowering to tensor.collapse_shape. The `start` and `end` values must be constant, otherwise the lowering fails.
4) Update xfail_sets.py (append to LTC_XFAIL_SET) after running ./tools/e2e_test.sh --filter Collapse --verbose -c XX for each supported backend (XX).

Motivation:
- Supporting the collapse operation will be useful for lowering pixel_shuffle (see issue #2559).
1 parent 6be9789 commit e81282a
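
For reference, a minimal sketch of the op's semantics being lowered here (it mirrors the example recorded in the new shape function below; assumes a PyTorch build that exposes torch._prims):

    import torch

    # prims.collapse flattens the inclusive dimension range [start, end] into one.
    t = torch.empty(2, 3, 4)
    print(torch._prims.collapse(t, 1, 2).shape)  # torch.Size([2, 12])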

File tree

8 files changed, +323 −0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 25 additions & 0 deletions
@@ -14185,6 +14185,31 @@ def Torch_PrimsSqrtOp : Torch_Op<"prims.sqrt", [
   }];
 }
 
+def Torch_PrimsCollapseOp : Torch_Op<"prims.collapse", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `prims::collapse : (Tensor, int, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$a,
+    Torch_IntType:$start,
+    Torch_IntType:$end
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult PrimsCollapseOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void PrimsCollapseOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_PrimsSqueezeOp : Torch_Op<"prims.squeeze", [
     AllowsTypeRefinement,
     ReadOnly

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 71 additions & 0 deletions
@@ -25,6 +25,7 @@
 #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/ADT/APSInt.h"
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -1298,6 +1299,7 @@ class ConvertElementwiseOp : public ConversionPattern {
 // nll_loss_forward[i] = -(input[i][indi]);
 // TODO: `weight`operand is still to be taken care of.
 namespace {
+
 class ConvertAtenNllLossForwardOp
     : public OpConversionPattern<AtenNllLossForwardOp> {
 public:
@@ -1757,6 +1759,71 @@ class ConvertAtenDetachOp : public OpConversionPattern<AtenDetachOp> {
 };
 } // namespace
 
+namespace {
+class ConvertPrimsCollapseOp : public OpConversionPattern<PrimsCollapseOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(PrimsCollapseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    auto aRankedTensorType = adaptor.getA().getType().cast<RankedTensorType>();
+    const TypeConverter *typeConverter = getTypeConverter();
+
+    auto resultRankedTensorType =
+        typeConverter->convertType(op.getType()).cast<RankedTensorType>();
+
+    // Collapse range must be statically known.
+    int64_t startInt;
+    if (!matchPattern(op.getStart(), m_TorchConstantInt(&startInt)))
+      return failure();
+
+    int64_t endInt;
+    if (!matchPattern(op.getEnd(), m_TorchConstantInt(&endInt)))
+      return failure();
+
+    // Upstream MLIR is overly strict -- it fails verification if the
+    // collapse_shape is the identity op (i.e. when no dimensions are
+    // collapsed). We manually fold this case here.
+    if (startInt == endInt) {
+      rewriter.replaceOp(op, adaptor.getA());
+      return success();
+    }
+
+    SmallVector<ReassociationIndices> associations;
+    associations.reserve(resultRankedTensorType.getRank());
+
+    // An example is where the input shape is [3,4,5,6] and
+    // start = 1, and end = 2. The collapsed shape is then [3,4*5,6],
+    // with reassociation indices of [0], [1,2], and [3].
+
+    // Append the singleton dimensions before the collapsed dimensions.
+    for (unsigned i = 0; i < startInt; ++i) {
+      associations.push_back(ReassociationIndices{i});
+    }
+
+    // Append the collapsed dimensions.
+    ReassociationIndices collapseDims(endInt + 1 - startInt);
+    std::iota(collapseDims.begin(), collapseDims.end(), startInt);
+    associations.push_back(collapseDims);
+
+    // Append the singleton dimensions after the collapsed dimensions.
+    for (int i = endInt + 1; i < aRankedTensorType.getRank(); ++i) {
+      associations.push_back(ReassociationIndices{i});
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        op, resultRankedTensorType, adaptor.getA(), associations);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertTensorStaticInfoCastOp
     : public OpConversionPattern<TensorStaticInfoCastOp> {
@@ -1805,6 +1872,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   patterns.add<ConvertAtenNllLossForwardOp>(typeConverter, context);
   target.addIllegalOp<AtenBatchNormOp>();
   patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
+
+  target.addIllegalOp<PrimsCollapseOp>();
+  patterns.add<ConvertPrimsCollapseOp>(typeConverter, context);
+
   target.addIllegalOp<AtenNllLossBackwardOp>();
   patterns.add<ConvertAtenNllLossBackwardOp>(typeConverter, context);
   patterns.add<ConvertTensorStaticInfoCastOp>(typeConverter, context);

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 78 additions & 0 deletions
@@ -6461,6 +6461,80 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.prims.collapse\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {\n"
+"    %true = torch.constant.bool true\n"
+"    %str = torch.constant.str \"AssertionError: start must be less than or equal to end\"\n"
+"    %str_0 = torch.constant.str \"AssertionError: end out of bounds\"\n"
+"    %none = torch.constant.none\n"
+"    %str_1 = torch.constant.str \"AssertionError: start out of bounds\"\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.le.int %arg1, %0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %3 = torch.aten.le.int %arg2, %2 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %3 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %4 = torch.aten.ge.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %4 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %5 = torch.aten.ge.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %5 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %6 = torch.aten.le.int %arg1, %arg2 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %6 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %7 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"    torch.prim.Loop %arg1, %true, init() {\n"
+"    ^bb0(%arg3: !torch.int):\n"
+"      %15 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %16 = torch.aten.append.t %7, %15 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"      torch.prim.Loop.condition %true, iter()\n"
+"    } : (!torch.int, !torch.bool) -> ()\n"
+"    %8 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"    %9 = torch.aten.__range_length %arg1, %8, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+"    %10 = torch.prim.Loop %9, %true, init(%int1) {\n"
+"    ^bb0(%arg3: !torch.int, %arg4: !torch.int):\n"
+"      %15 = torch.aten.__derive_index %arg3, %arg1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+"      %16 = torch.aten.__getitem__.t %arg0, %15 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %17 = torch.aten.mul.int %arg4, %16 : !torch.int, !torch.int -> !torch.int\n"
+"      torch.prim.Loop.condition %true, iter(%17 : !torch.int)\n"
+"    } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
+"    %11 = torch.aten.append.t %7, %10 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"    %12 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"    %13 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %14 = torch.aten.__range_length %12, %13, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+"    torch.prim.Loop %14, %true, init() {\n"
+"    ^bb0(%arg3: !torch.int):\n"
+"      %15 = torch.aten.__derive_index %arg3, %12, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+"      %16 = torch.aten.__getitem__.t %arg0, %15 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %17 = torch.aten.append.t %7, %16 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"      torch.prim.Loop.condition %true, iter()\n"
+"    } : (!torch.int, !torch.bool) -> ()\n"
+"    return %7 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.to.dtype\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -11295,6 +11369,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.prims.collapse\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "}\n"
 "";
 // clang-format on

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 5 additions & 0 deletions
@@ -1355,6 +1355,11 @@
 }
 
 LTC_XFAIL_SET = {
+    "CollapseAllDimensionsModule_basic",
+    "CollapseRank1DynamicModule_basic",
+    "CollapseStaticModule_basic",
+    "CollapsePartialDynamicModule_basic",
+    "CollapseFullDynamicModule_basic",
     "PixelShuffleModuleStaticRank3Int64_basic",
     "PixelShuffleModuleStaticRank4Float32_basic",
     "_Convolution2DAllFalseModule_basic",

projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py

Lines changed: 43 additions & 0 deletions
@@ -177,6 +177,8 @@ def aten〇glu〡shape(self: List[int], dim: int = -1) -> List[int]:
     assert self[dim] % 2 == 0, "glu's dim size must be multiply of 2"
     return self[:dim] + [self[dim] // 2] + self[dim+1:]
 
+
+
 def aten〇_softmax〡shape(self: List[int], dim: int, half_to_float: bool) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -204,6 +206,40 @@ def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1
 def prims〇convert_element_type〡shape(a: List[int], dtype: int) -> List[int]:
     return upstream_shape_functions.unary(a)
 
+def prims〇collapse〡shape(a: List[int], start: int, end: int) -> List[int]:
+    # Obtained through trial and error on a few examples in PyTorch:
+    assert start <= len(a), "start out of bounds"
+    assert end <= len(a), "end out of bounds"
+    assert start >= 0, "start out of bounds"
+    assert end >= 0, "end out of bounds"
+    assert start <= end, "start must be less than or equal to end"
+
+    # Example:
+    #
+    # torch._prims.collapse(torch.empty(2,3,4), 1, 2).shape
+    # is
+    # torch.Size([2, 12])
+
+    collapsed: List[int] = []
+    for i in range(start):
+        collapsed.append(a[i])
+
+    # For the example, here collapsed is [2]
+    combined = 1
+    for i in range(start, end + 1):
+        combined *= a[i]
+
+    collapsed.append(combined)
+
+    # For the example, here collapsed is [2, 12]
+
+    for i in range(end + 1, len(a)):
+        collapsed.append(a[i])
+
+    # For the example, here collapsed is [2, 12]
+
+    return collapsed
+
 def aten〇to〇dtype〡shape(self: List[int], dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -905,6 +941,7 @@ def aten〇squeeze〇dim〡shape(self: List[int], dim: int) -> List[int]:
 def prims〇squeeze〡shape(a: List[int], dimensions: List[int]) -> List[int]:
     return upstream_shape_functions.squeeze_dims(a, dimensions)
 
+
 def prims〇view_of〡shape(a: List[int]) -> List[int]:
     return a
 
@@ -3693,6 +3730,12 @@ def prims〇squeeze〡dtype(a_rank_dtype: Tuple[int, int], dimensions: List[int]
     return a_dtype
 
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, start=0, end=0))
+def prims〇collapse〡dtype(a_rank_dtype: Tuple[int, int], start: int, end: int) -> int:
+    a_rank, a_dtype = a_rank_dtype
+    return a_dtype
+
+
 
 # ==============================================================================
 # Main
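
A standalone way to sanity-check the new shape function (a plain-Python equivalent, not taken from the commit):

    from typing import List

    def collapse_shape(a: List[int], start: int, end: int) -> List[int]:
        # Dims in the inclusive range [start, end] are multiplied together.
        assert 0 <= start <= end <= len(a)
        combined = 1
        for d in a[start:end + 1]:
            combined *= d
        return a[:start] + [combined] + a[end + 1:]

    assert collapse_shape([2, 3, 4], 1, 2) == [2, 12]
    assert collapse_shape([3, 4, 5, 6], 1, 2) == [3, 20, 6]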

projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -817,6 +817,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("prims::convert_element_type : (Tensor, int) -> (Tensor)")
     emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)")
     emit("prims::sqrt : (Tensor) -> (Tensor)")
+    emit("prims::collapse : (Tensor, int, int) -> (Tensor)")
     emit("prims::squeeze : (Tensor, int[]) -> (Tensor)")
     emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True)
 

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 1 addition & 0 deletions
@@ -341,6 +341,7 @@ def ElementwiseUnsqueezeBroadcastModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4), tu.rand())
 
 
+
 # ==============================================================================
 
 