Skip to content

Commit 98c6971

Browse files
BaneTrifaBranko Trifkovic
and
Branko Trifkovic
authored
Implement lowering of torch.aten.triu_indices (#3451)
Closes [nod-ai/SHARK-ModelDev/issues/709](nod-ai/SHARK-ModelDev#709) --------- Co-authored-by: Branko Trifkovic <[email protected]>
1 parent acd57a3 commit 98c6971

File tree

9 files changed

+545
-0
lines changed

9 files changed

+545
-0
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15517,6 +15517,36 @@ def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [
1551715517
let hasCanonicalizer = 1;
1551815518
}
1551915519

15520+
def Torch_AtenTriuIndicesOp : Torch_Op<"aten.triu_indices", [
15521+
AllowsTypeRefinement,
15522+
HasValueSemantics,
15523+
ReadOnly
15524+
]> {
15525+
let summary = "Generated op for `aten::triu_indices : (int, int, int, int?, int?, Device?, bool?) -> (Tensor)`";
15526+
let arguments = (ins
15527+
Torch_IntType:$row,
15528+
Torch_IntType:$col,
15529+
Torch_IntType:$offset,
15530+
AnyTorchOptionalIntType:$dtype,
15531+
AnyTorchOptionalIntType:$layout,
15532+
AnyTorchOptionalDeviceType:$device,
15533+
AnyTorchOptionalBoolType:$pin_memory
15534+
);
15535+
let results = (outs
15536+
AnyTorchOptionalTensorType:$result
15537+
);
15538+
let hasCustomAssemblyFormat = 1;
15539+
let extraClassDefinition = [{
15540+
ParseResult AtenTriuIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
15541+
return parseDefaultTorchOp(parser, result, 7, 1);
15542+
}
15543+
void AtenTriuIndicesOp::print(OpAsmPrinter &printer) {
15544+
printDefaultTorchOp(printer, *this, 7, 1);
15545+
}
15546+
}];
15547+
let hasVerifier = 1;
15548+
}
15549+
1552015550
def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
1552115551
AllowsTypeRefinement,
1552215552
HasValueSemantics,

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5212,3 +5212,39 @@ LogicalResult BindSymbolicShapeOp::verify() {
52125212

52135213
return success();
52145214
}
5215+
// AtenTriuIndicesOp
5216+
//===----------------------------------------------------------------------===//
5217+
5218+
LogicalResult AtenTriuIndicesOp::verify() {
5219+
5220+
// Check if row, col and offset are constant ints
5221+
int64_t row;
5222+
if (!matchPattern(getRow(), m_TorchConstantInt(&row)))
5223+
return success();
5224+
5225+
int64_t col;
5226+
if (!matchPattern(getCol(), m_TorchConstantInt(&col)))
5227+
return success();
5228+
5229+
int64_t offset;
5230+
if (!matchPattern(getOffset(), m_TorchConstantInt(&offset)))
5231+
return success();
5232+
5233+
// Check if values of row, and col are valid
5234+
if (row < 0)
5235+
return emitOpError("row must be non-negative, got ") << row;
5236+
5237+
if (col < 0)
5238+
return emitOpError("col must be non-negative, got ") << col;
5239+
5240+
// Check if dtype is valid
5241+
int64_t dtype;
5242+
if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype)))
5243+
return success();
5244+
if (dtype != (int)torch_upstream::ScalarType::Int &&
5245+
dtype != (int)torch_upstream::ScalarType::Long)
5246+
return emitOpError(
5247+
"'triu_indices' implemented only for torch.int32 and torch.int64");
5248+
5249+
return success();
5250+
}

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9729,6 +9729,68 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
97299729
" %1 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4, %arg6, %0) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.int, !torch.optional<list<int>>, !torch.optional<int>) -> !torch.tuple<list<int>, list<int>, list<int>, list<int>>\n"
97309730
" return %1 : !torch.tuple<list<int>, list<int>, list<int>, list<int>>\n"
97319731
" }\n"
9732+
" func.func @\"__torch_mlir_shape_fn.aten.triu_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
9733+
" %true = torch.constant.bool true\n"
9734+
" %int0 = torch.constant.int 0\n"
9735+
" %int2 = torch.constant.int 2\n"
9736+
" %int1 = torch.constant.int 1\n"
9737+
" %0 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
9738+
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
9739+
" torch.prim.If.yield %true : !torch.bool\n"
9740+
" } else {\n"
9741+
" %3 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
9742+
" torch.prim.If.yield %3 : !torch.bool\n"
9743+
" }\n"
9744+
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
9745+
" %3 = torch.prim.ListConstruct %int2, %int0 : (!torch.int, !torch.int) -> !torch.list<int>\n"
9746+
" torch.prim.If.yield %3 : !torch.list<int>\n"
9747+
" } else {\n"
9748+
" %3 = torch.aten.sub.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n"
9749+
" %4 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
9750+
" %5 = torch.prim.If %4 -> (!torch.bool) {\n"
9751+
" torch.prim.If.yield %true : !torch.bool\n"
9752+
" } else {\n"
9753+
" %11 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
9754+
" torch.prim.If.yield %11 : !torch.bool\n"
9755+
" }\n"
9756+
" %6:2 = torch.prim.If %5 -> (!torch.int, !torch.int) {\n"
9757+
" torch.prim.If.yield %int0, %int0 : !torch.int, !torch.int\n"
9758+
" } else {\n"
9759+
" %11 = torch.aten.gt.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
9760+
" %12 = torch.prim.If %11 -> (!torch.int) {\n"
9761+
" %27 = torch.aten.add.int %int1, %3 : !torch.int, !torch.int -> !torch.int\n"
9762+
" %28 = torch.prim.min.int %arg1, %27 : !torch.int, !torch.int -> !torch.int\n"
9763+
" torch.prim.If.yield %28 : !torch.int\n"
9764+
" } else {\n"
9765+
" %27 = torch.aten.add.int %arg0, %3 : !torch.int, !torch.int -> !torch.int\n"
9766+
" %28 = torch.aten.gt.int %27, %int0 : !torch.int, !torch.int -> !torch.bool\n"
9767+
" %29 = torch.aten.Int.bool %28 : !torch.bool -> !torch.int\n"
9768+
" torch.prim.If.yield %29 : !torch.int\n"
9769+
" }\n"
9770+
" %13 = torch.aten.add.int %arg0, %3 : !torch.int, !torch.int -> !torch.int\n"
9771+
" %14 = torch.prim.min.int %arg1, %13 : !torch.int, !torch.int -> !torch.int\n"
9772+
" %15 = torch.prim.max.int %int0, %14 : !torch.int, !torch.int -> !torch.int\n"
9773+
" %16 = torch.aten.add.int %arg0, %3 : !torch.int, !torch.int -> !torch.int\n"
9774+
" %17 = torch.prim.min.int %arg0, %16 : !torch.int, !torch.int -> !torch.int\n"
9775+
" %18 = torch.prim.max.int %int0, %17 : !torch.int, !torch.int -> !torch.int\n"
9776+
" %19 = torch.aten.sub.int %15, %12 : !torch.int, !torch.int -> !torch.int\n"
9777+
" %20 = torch.aten.add.int %19, %int1 : !torch.int, !torch.int -> !torch.int\n"
9778+
" %21 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int\n"
9779+
" %22 = torch.aten.mul.int %21, %20 : !torch.int, !torch.int -> !torch.int\n"
9780+
" %23 = torch.aten.floordiv.int %22, %int2 : !torch.int, !torch.int -> !torch.int\n"
9781+
" %24 = torch.aten.sub.int %18, %20 : !torch.int, !torch.int -> !torch.int\n"
9782+
" %25 = torch.aten.mul.int %24, %arg1 : !torch.int, !torch.int -> !torch.int\n"
9783+
" %26 = torch.prim.max.int %int0, %25 : !torch.int, !torch.int -> !torch.int\n"
9784+
" torch.prim.If.yield %23, %26 : !torch.int, !torch.int\n"
9785+
" }\n"
9786+
" %7 = torch.aten.mul.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int\n"
9787+
" %8 = torch.aten.add.int %6#0, %6#1 : !torch.int, !torch.int -> !torch.int\n"
9788+
" %9 = torch.aten.sub.int %7, %8 : !torch.int, !torch.int -> !torch.int\n"
9789+
" %10 = torch.prim.ListConstruct %int2, %9 : (!torch.int, !torch.int) -> !torch.list<int>\n"
9790+
" torch.prim.If.yield %10 : !torch.list<int>\n"
9791+
" }\n"
9792+
" return %2 : !torch.list<int>\n"
9793+
" }\n"
97329794
" func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
97339795
" %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
97349796
" return %0 : !torch.tuple<list<int>, list<int>>\n"
@@ -14023,6 +14085,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1402314085
" %int6 = torch.constant.int 6\n"
1402414086
" return %int6 : !torch.int\n"
1402514087
" }\n"
14088+
" func.func @\"__torch_mlir_dtype_fn.aten.triu_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
14089+
" %int4 = torch.constant.int 4\n"
14090+
" %none = torch.constant.none\n"
14091+
" %0 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
14092+
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
14093+
" torch.prim.If.yield %int4 : !torch.int\n"
14094+
" } else {\n"
14095+
" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
14096+
" torch.prim.If.yield %2 : !torch.int\n"
14097+
" }\n"
14098+
" return %1 : !torch.int\n"
14099+
" }\n"
1402614100
" func.func @\"__torch_mlir_dtype_fn.aten.int_repr\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
1402714101
" %int3 = torch.constant.int 3\n"
1402814102
" %int1 = torch.constant.int 1\n"

0 commit comments

Comments
 (0)