
Commit fd48b5b

stbaione authored and TimAtGoogle committed
support aten._trilinear and improve einsum decomposition (#3784)
# Tracking
[Issue](nod-ai/SHARK-ModelDev#848)
[TorchToLinalg Op Support](nod-ai/SHARK-ModelDev#347)

# Description
Aten_TrilinearOp is an implementation of a "trilinear einstein sum": essentially an einsum across three tensors. It takes a few inputs:

## Tensor inputs
- i1, i2, i3 - the three input tensors for the _trilinear op.

## Expands
These inputs unsqueeze an input tensor at the specified dims as a pre-processing step, making the shapes compatible for the rest of the op:
- expand1: List[int], expand2: List[int], expand3: List[int]

## sumdim
- sumdim: List[int] - after the element-wise multiplication, each dimension listed in sumdim is collapsed by summing over it.

## unroll_dim
- unroll_dim: int - in the PyTorch implementation, this specifies a dimension along which the input tensors can be sliced, multiplied, and summed, with the partial results concatenated into the output tensor. It complicates the implementation significantly without changing the result, so I opted against handling it. The previously accepted path of reusing AtenEinsumOp would also have ignored this input.

# Solution
After trying a number of more complicated approaches, this op turned out to be quite simple ([see _trilinear](https://dev-discuss.pytorch.org/t/defining-the-core-aten-opset/1464)):

`_trilinear = (i1.unsqueeze(expand1) * i2.unsqueeze(expand2) * i3.unsqueeze(expand3)).sum(sumdim)`

Wish I had seen this earlier, but whatcha gonna do 🙃 A quick numerical check of this identity appears at the end of this description.

## Not reusing AtenEinsumOp
Frankly, I found multiple cases where valid inputs produced numerical mismatches with EinsumOp, even when running tests against EinsumOp directly. It appears to be related to the singleton dimensions and needs further investigation, but once I realized the simplified approach, it proved both more reliable and much simpler.

Either way (credit to @zjgarvey), this patch also improves the einsum op. When I originally tried to reuse it, intermediate tensors were being flattened properly, but their 0th dimension was then cast from a static dim to a dynamic dim because integers were not folding correctly in the MLIR. These improvements are worth keeping for future users of EinsumOp.

# The zero'd out dim "bug"
For some reason, if you specify the same dimension in all `expands`, e.g.

```
[expand1=[0], expand2=[0], expand3=[0]]
[expand1=[1], expand2=[1], expand3=[1]]
```

the _trilinear op reports `0` for that dimension in the output shape, unless the dimension is also included in `sumdim`. This contradicts torch.einsum:

```
>>> a, b, c = [torch.rand(1, 3, 3, 3) for i in range(3)]
# Simulate expand at dim=0 for all input tensors
>>> torch.einsum('abcd,abcd,abcd->abcd', a, b, c).shape
torch.Size([1, 3, 3, 3])
```

and is simply incorrect mathematically. I considered "replacing" singleton dims with zeroed-out dims, but that seemed like carrying the bug over. Instead, I included a test for this case, verified that the singleton dimensions are handled the way torch.einsum handles them rather than the way torch._trilinear does, and xfailed it with a note explaining why.
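A minimal numerical check of the decomposition identity, assuming a local PyTorch install; the shapes and dims below are illustrative only (not taken from the new e2e tests), and `unroll_dim` is left at its default of 1:

```python
import torch

i1 = torch.rand(2, 3)
i2 = torch.rand(2, 4)
i3 = torch.rand(3, 4)

# Unsqueeze so all three inputs broadcast to (2, 3, 4):
# i1 -> (2, 3, 1), i2 -> (2, 1, 4), i3 -> (1, 3, 4), then sum out dim 0.
expand1, expand2, expand3, sumdim = [2], [1], [0], [0]

ref = torch.ops.aten._trilinear(i1, i2, i3, expand1, expand2, expand3, sumdim, 1)
dec = (i1.unsqueeze(2) * i2.unsqueeze(1) * i3.unsqueeze(0)).sum(dim=sumdim)

assert torch.allclose(ref, dec)
print(dec.shape)  # torch.Size([3, 4])
```

This identity is exactly what the new decomposition pattern lowers to.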
1 parent acb8198 commit fd48b5b

File tree

8 files changed: +553 additions, -14 deletions


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 30 additions & 0 deletions
```diff
@@ -14248,6 +14248,36 @@ def Torch_AtenGridSamplerOp : Torch_Op<"aten.grid_sampler", [
   }];
 }
 
+def Torch_Aten_TrilinearOp : Torch_Op<"aten._trilinear", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_trilinear : (Tensor, Tensor, Tensor, int[], int[], int[], int[], int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$i1,
+    AnyTorchTensorType:$i2,
+    AnyTorchTensorType:$i3,
+    AnyTorchListOfTorchIntType:$expand1,
+    AnyTorchListOfTorchIntType:$expand2,
+    AnyTorchListOfTorchIntType:$expand3,
+    AnyTorchListOfTorchIntType:$sumdim,
+    Torch_IntType:$unroll_dim
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_TrilinearOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 8, 1);
+    }
+    void Aten_TrilinearOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 8, 1);
+    }
+  }];
+}
+
 def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
     AllowsTypeRefinement,
     HasValueSemantics,
```

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 115 additions & 0 deletions
```diff
@@ -8864,6 +8864,112 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten._trilinear\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.int) -> !torch.list<int> {\n"
+" %int3 = torch.constant.int 3\n"
+" %int-1 = torch.constant.int -1\n"
+" %str = torch.constant.str \"AssertionError: number of dimensions must match\"\n"
+" %str_0 = torch.constant.str \"expand dimension {} is out of bounds for input of shape {}\"\n"
+" %true = torch.constant.bool true\n"
+" %none = torch.constant.none\n"
+" %str_1 = torch.constant.str \"AssertionError: \"\n"
+" %str_2 = torch.constant.str \"unroll_dim must be in [0, {}]\"\n"
+" %false = torch.constant.bool false\n"
+" %int0 = torch.constant.int 0\n"
+" %int1 = torch.constant.int 1\n"
+" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+" %1 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int\n"
+" %2 = torch.aten.add.int %0, %1 : !torch.int, !torch.int -> !torch.int\n"
+" %3 = torch.aten.ge.int %arg7, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
+" %23 = torch.aten.lt.int %arg7, %2 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If.yield %23 : !torch.bool\n"
+" } else {\n"
+" torch.prim.If.yield %false : !torch.bool\n"
+" }\n"
+" torch.prim.If %4 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" %23 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n"
+" %24 = torch.aten.format(%str_2, %23) : !torch.str, !torch.int -> !torch.str\n"
+" %25 = torch.aten.add.str %str_1, %24 : !torch.str, !torch.str -> !torch.str\n"
+" torch.prim.RaiseException %25, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" %5 = call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+" %6 = call @__torch__.torch.jit._shape_functions._copy(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
+" %7 = call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list<int>) -> !torch.list<int>\n"
+" %8 = torch.prim.ListConstruct %5, %6, %7 : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<list<int>>\n"
+" %9 = torch.prim.ListConstruct %arg3, %arg4, %arg5 : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<list<int>>\n"
+" torch.prim.Loop %int3, %true, init() {\n"
+" ^bb0(%arg8: !torch.int):\n"
+" %23 = torch.aten.__getitem__.t %9, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
+" %24 = torch.aten.__getitem__.t %8, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
+" %25 = torch.aten.len.t %24 : !torch.list<int> -> !torch.int\n"
+" %26 = torch.aten.len.t %23 : !torch.list<int> -> !torch.int\n"
+" torch.prim.Loop %26, %true, init() {\n"
+" ^bb0(%arg9: !torch.int):\n"
+" %27 = torch.aten.__getitem__.t %23, %arg9 : !torch.list<int>, !torch.int -> !torch.int\n"
+" %28 = torch.aten.le.int %27, %25 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If %28 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" %30 = torch.aten.__getitem__.t %8, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
+" %31 = torch.aten.format(%str_0, %27, %30) : !torch.str, !torch.int, !torch.list<int> -> !torch.str\n"
+" %32 = torch.aten.add.str %str_1, %31 : !torch.str, !torch.str -> !torch.str\n"
+" torch.prim.RaiseException %32, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" %29 = torch.aten.__getitem__.t %8, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
+" torch.aten.insert.t %29, %27, %int1 : !torch.list<int>, !torch.int, !torch.int\n"
+" torch.prim.Loop.condition %true, iter()\n"
+" } : (!torch.int, !torch.bool) -> ()\n"
+" torch.prim.Loop.condition %true, iter()\n"
+" } : (!torch.int, !torch.bool) -> ()\n"
+" %10 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int\n"
+" %11 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int\n"
+" %12 = torch.aten.eq.int %10, %11 : !torch.int, !torch.int -> !torch.bool\n"
+" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
+" %23 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int\n"
+" %24 = torch.aten.len.t %7 : !torch.list<int> -> !torch.int\n"
+" %25 = torch.aten.eq.int %23, %24 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If.yield %25 : !torch.bool\n"
+" } else {\n"
+" torch.prim.If.yield %false : !torch.bool\n"
+" }\n"
+" torch.prim.If %13 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" %14 = call @__torch__.torch.jit._shape_functions.broadcast_three(%5, %6, %7) : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+" %15 = torch.prim.ListConstruct %false : (!torch.bool) -> !torch.list<bool>\n"
+" %16 = torch.aten.len.t %14 : !torch.list<int> -> !torch.int\n"
+" %17 = torch.operator \"aten.mul.left_t\"(%15, %16) : (!torch.list<bool>, !torch.int) -> !torch.list<bool> \n"
+" %18 = torch.aten.len.t %arg6 : !torch.list<int> -> !torch.int\n"
+" torch.prim.Loop %18, %true, init() {\n"
+" ^bb0(%arg8: !torch.int):\n"
+" %23 = torch.aten.__getitem__.t %arg6, %arg8 : !torch.list<int>, !torch.int -> !torch.int\n"
+" %24 = torch.aten._set_item.t %17, %23, %true : !torch.list<bool>, !torch.int, !torch.bool -> !torch.list<bool>\n"
+" torch.prim.Loop.condition %true, iter()\n"
+" } : (!torch.int, !torch.bool) -> ()\n"
+" %19 = torch.aten.len.t %14 : !torch.list<int> -> !torch.int\n"
+" %20 = torch.aten.sub.int %19, %int1 : !torch.int, !torch.int -> !torch.int\n"
+" %21 = torch.aten.__range_length %20, %int-1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+" %22 = torch.prim.Loop %21, %true, init(%14) {\n"
+" ^bb0(%arg8: !torch.int, %arg9: !torch.list<int>):\n"
+" %23 = torch.aten.__derive_index %arg8, %20, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+" %24 = torch.aten.__getitem__.t %17, %23 : !torch.list<bool>, !torch.int -> !torch.bool\n"
+" %25 = torch.prim.If %24 -> (!torch.list<int>) {\n"
+" %26 = func.call @__torch__.torch.jit._shape_functions._reduce_along_dim(%arg9, %23, %false) : (!torch.list<int>, !torch.int, !torch.bool) -> !torch.list<int>\n"
+" torch.prim.If.yield %26 : !torch.list<int>\n"
+" } else {\n"
+" torch.prim.If.yield %arg9 : !torch.list<int>\n"
+" }\n"
+" torch.prim.Loop.condition %true, iter(%25 : !torch.list<int>)\n"
+" } : (!torch.int, !torch.bool, !torch.list<int>) -> !torch.list<int>\n"
+" return %22 : !torch.list<int>\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>, %arg7: !torch.bool) -> !torch.list<int> {\n"
 " %int-1 = torch.constant.int -1\n"
 " %0 = torch.aten.__getitem__.t %arg2, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
@@ -15294,6 +15400,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 " return %4 : !torch.int\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten._trilinear\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.int) -> !torch.int {\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+" return %5 : !torch.int\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.cat\"(%arg0: !torch.list<tuple<int, int>>, %arg1: !torch.int) -> !torch.int {\n"
 " %true = torch.constant.bool true\n"
 " %none = torch.constant.none\n"
```

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 152 additions & 9 deletions
```diff
@@ -9,6 +9,7 @@
 
 #include "PassDetail.h"
 
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -399,9 +400,9 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
   auto inputType = cast<ValueTensorType>(input.getType());
   auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
                    reduceDimsLength;
-  SmallVector<Value> inputShapeTensor;
+  SmallVector<OpFoldResult> inputShapeTensor;
   for (auto i = 0; i < inputRank; ++i) {
-    inputShapeTensor.emplace_back(rewriter.create<AtenSizeIntOp>(
+    inputShapeTensor.emplace_back(rewriter.createOrFold<AtenSizeIntOp>(
         loc, input,
         rewriter.create<Torch::ConstantIntOp>(loc,
                                               rewriter.getI64IntegerAttr(i))));
@@ -412,13 +413,23 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
       rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
   auto dimOffset = 0;
 
+  auto materializeIntFold = [&](OpFoldResult thing) {
+    if (auto attr = dyn_cast<mlir::Attribute>(thing)) {
+      Value result = rewriter.create<Torch::ConstantIntOp>(
+          loc, cast<mlir::IntegerAttr>(attr));
+      return result;
+    }
+    return cast<mlir::Value>(thing);
+  };
+
   auto appendDims = [&](int64_t dimLength) {
-    Value prod = constOne;
+    OpFoldResult prod = getAsOpFoldResult(constOne);
     for (auto i = 0; i < dimLength; ++i) {
-      prod = rewriter.create<AtenMulIntOp>(loc, prod,
-                                           inputShapeTensor[i + dimOffset]);
+      prod = rewriter.createOrFold<AtenMulIntOp>(
+          loc, materializeIntFold(prod),
+          materializeIntFold(inputShapeTensor[i + dimOffset]));
     }
-    outShapeTensor.emplace_back(prod);
+    outShapeTensor.emplace_back(materializeIntFold(prod));
     dimOffset += dimLength;
   };
 
@@ -570,21 +581,32 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
   Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
                                         : rhsType.getOptionalDtype();
 
+  auto materializeIntFold = [&](OpFoldResult thing) {
+    if (auto attr = dyn_cast<mlir::Attribute>(thing)) {
+      Value result = rewriter.create<Torch::ConstantIntOp>(
+          loc, cast<mlir::IntegerAttr>(attr));
+      return result;
+    }
+    return cast<mlir::Value>(thing);
+  };
+
   llvm::SmallDenseMap<char, Value> lhsDimShapeMap;
   for (size_t idx = 0; idx < lhsTokens.size(); ++idx) {
     char d = lhsTokens[idx];
-    lhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
+    OpFoldResult lhsFold = rewriter.createOrFold<AtenSizeIntOp>(
        loc, lhs,
        rewriter.create<Torch::ConstantIntOp>(loc,
                                              rewriter.getI64IntegerAttr(idx)));
+    lhsDimShapeMap[d] = materializeIntFold(lhsFold);
   }
   llvm::SmallDenseMap<char, Value> rhsDimShapeMap;
   for (size_t idx = 0; idx < rhsTokens.size(); ++idx) {
     char d = rhsTokens[idx];
-    rhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
+    OpFoldResult rhsFold = rewriter.createOrFold<AtenSizeIntOp>(
        loc, rhs,
        rewriter.create<Torch::ConstantIntOp>(loc,
                                              rewriter.getI64IntegerAttr(idx)));
+    rhsDimShapeMap[d] = materializeIntFold(rhsFold);
   }
 
   // parse batch, contracting, other, reduce dims of lhs and rhs
@@ -604,8 +626,9 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
     bool lhsContains = lhsDimShapeMap.count(d) > 0;
     bool rhsContains = rhsDimShapeMap.count(d) > 0;
     if (lhsContains && rhsContains) {
-      outDimShapeMap[d] = rewriter.create<Torch::PrimMaxIntOp>(
+      OpFoldResult out = rewriter.createOrFold<Torch::PrimMaxIntOp>(
          loc, lhsDimShapeMap[d], rhsDimShapeMap[d]);
+      outDimShapeMap[d] = materializeIntFold(out);
    } else if (lhsContains) {
      outDimShapeMap[d] = lhsDimShapeMap[d];
    } else if (rhsContains) {
@@ -1973,6 +1996,125 @@ class DecomposeAtenEinsumOp : public OpRewritePattern<AtenEinsumOp> {
 };
 } // namespace
 
+namespace {
+// Trilinear einstein sum, decomposed to:
+// (i1.unsqueeze(expand1) * i2.unsqueeze(expand2) * i3.unsqueeze(expand3))
+//    .sum(sumdim)
+// The unrollDim operand does not impact the output of the operation, so
+// it is ignored.
+
+class DecomposeAten_TrilinearOp : public OpRewritePattern<Aten_TrilinearOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(Aten_TrilinearOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+
+    Value input1 = op.getI1();
+    Value input2 = op.getI2();
+    Value input3 = op.getI3();
+
+    // Expansions
+    SmallVector<int64_t> expand1;
+    SmallVector<int64_t> expand2;
+    SmallVector<int64_t> expand3;
+    if (!matchPattern(op.getExpand1(), m_TorchListOfConstantInts(expand1))) {
+      return rewriter.notifyMatchFailure(op, "expand1 should be constant");
+    }
+    if (!matchPattern(op.getExpand2(), m_TorchListOfConstantInts(expand2))) {
+      return rewriter.notifyMatchFailure(op, "expand2 should be constant");
+    }
+    if (!matchPattern(op.getExpand3(), m_TorchListOfConstantInts(expand3))) {
+      return rewriter.notifyMatchFailure(op, "expand3 should be constant");
+    }
+
+    SmallVector<int64_t> sumDim;
+    if (!matchPattern(op.getSumdim(), m_TorchListOfConstantInts(sumDim))) {
+      return rewriter.notifyMatchFailure(op, "sumDim should be constant");
+    }
+
+    // Check if there are any dimensions that intersect between expand1,
+    // expand2, and expand3.
+    int64_t totalDims =
+        cast<BaseTensorType>(input1.getType()).getSizes().size() +
+        expand1.size();
+    if (sharedExpandDims(totalDims, expand1, expand2, expand3, sumDim)) {
+      // pytorch issue filed: https://github.com/pytorch/pytorch/issues/138353
+      // TODO: Remove warning when issue gets resolved.
+      op->emitWarning("aten::_trilinear implementation in this case is "
+                      "non-functional (returns an empty dimension). We will "
+                      "intentionally deviate from this behavior.");
+    }
+
+    // Apply unsqueeze to respective input tensors at the specified dimensions
+    SmallVector<int64_t> sortedExpand1 = expand1;
+    std::sort(sortedExpand1.begin(), sortedExpand1.end());
+    for (auto expand : sortedExpand1) {
+      Value expandDim = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(expand));
+      input1 = *unsqueezeTensor(rewriter, op, input1, expandDim);
+    }
+    SmallVector<int64_t> sortedExpand2 = expand2;
+    std::sort(sortedExpand2.begin(), sortedExpand2.end());
+    for (auto expand : sortedExpand2) {
+      Value expandDim = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(expand));
+      input2 = *unsqueezeTensor(rewriter, op, input2, expandDim);
+    }
+    SmallVector<int64_t> sortedExpand3 = expand3;
+    std::sort(sortedExpand3.begin(), sortedExpand3.end());
+    for (auto expand : sortedExpand3) {
+      Value expandDim = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(expand));
+      input3 = *unsqueezeTensor(rewriter, op, input3, expandDim);
+    }
+
+    // Apply multiplication operation.
+    auto mul1 =
+        rewriter.create<AtenMulTensorOp>(loc, op.getType(), input1, input2);
+    auto mul2 =
+        rewriter.create<AtenMulTensorOp>(loc, op.getType(), mul1, input3);
+
+    // Apply sum operation.
+    // Parse sumDim in descending order to avoid any issues with the
+    // dimensions being removed.
+    Value result = mul2;
+    SmallVector<int64_t> sortedSumDims = sumDim;
+    std::sort(sortedSumDims.rbegin(), sortedSumDims.rend());
+    for (int64_t dim : sortedSumDims) {
+      Value dimValue = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(dim));
+      result =
+          createSumAlongDimension(rewriter, loc, op, result, dimValue, false);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  // Determine if there are any dimensions that intersect between expand1,
+  // expand2, and expand3.
+  bool sharedExpandDims(const int64_t &totalDims,
+                        const SmallVector<int64_t> &expand1,
+                        const SmallVector<int64_t> &expand2,
+                        const SmallVector<int64_t> &expand3,
+                        const SmallVector<int64_t> &sumDim) const {
+    for (int64_t i = 0; i < totalDims; ++i) {
+      if (!contains(sumDim, i) && contains(expand1, i) &&
+          contains(expand2, i) && contains(expand3, i)) {
+        return true;
+      }
+    }
+    return false;
+  }
+  bool contains(const SmallVector<int64_t> &vec, int64_t value) const {
+    return std::find(vec.begin(), vec.end(), value) != vec.end();
+  }
+};
+} // namespace
+
 namespace {
 // Calculate the trace of the input tensor as the sum over its diagonal
 // elements. This computation is performed as:
@@ -10078,6 +10220,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAten_TrilinearOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenSoftplusOp>(patterns);
```