@@ -9729,6 +9729,68 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
9729
9729
" %1 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4, %arg6, %0) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.int, !torch.optional<list<int>>, !torch.optional<int>) -> !torch.tuple<list<int>, list<int>, list<int>, list<int>>\n"
9730
9730
" return %1 : !torch.tuple<list<int>, list<int>, list<int>, list<int>>\n"
9731
9731
" }\n"
9732
+ " func.func @\"__torch_mlir_shape_fn.aten.triu_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
9733
+ " %true = torch.constant.bool true\n"
9734
+ " %int0 = torch.constant.int 0\n"
9735
+ " %int2 = torch.constant.int 2\n"
9736
+ " %int1 = torch.constant.int 1\n"
9737
+ " %0 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
9738
+ " %1 = torch.prim.If %0 -> (!torch.bool) {\n"
9739
+ " torch.prim.If.yield %true : !torch.bool\n"
9740
+ " } else {\n"
9741
+ " %3 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
9742
+ " torch.prim.If.yield %3 : !torch.bool\n"
9743
+ " }\n"
9744
+ " %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
9745
+ " %3 = torch.prim.ListConstruct %int2, %int0 : (!torch.int, !torch.int) -> !torch.list<int>\n"
9746
+ " torch.prim.If.yield %3 : !torch.list<int>\n"
9747
+ " } else {\n"
9748
+ " %3 = torch.aten.sub.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n"
9749
+ " %4 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
9750
+ " %5 = torch.prim.If %4 -> (!torch.bool) {\n"
9751
+ " torch.prim.If.yield %true : !torch.bool\n"
9752
+ " } else {\n"
9753
+ " %11 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
9754
+ " torch.prim.If.yield %11 : !torch.bool\n"
9755
+ " }\n"
9756
+ " %6:2 = torch.prim.If %5 -> (!torch.int, !torch.int) {\n"
9757
+ " torch.prim.If.yield %int0, %int0 : !torch.int, !torch.int\n"
9758
+ " } else {\n"
9759
+ " %11 = torch.aten.gt.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
9760
+ " %12 = torch.prim.If %11 -> (!torch.int) {\n"
9761
+ " %27 = torch.aten.add.int %int1, %3 : !torch.int, !torch.int -> !torch.int\n"
9762
+ " %28 = torch.prim.min.int %arg1, %27 : !torch.int, !torch.int -> !torch.int\n"
9763
+ " torch.prim.If.yield %28 : !torch.int\n"
9764
+ " } else {\n"
9765
+ " %27 = torch.aten.add.int %arg0, %3 : !torch.int, !torch.int -> !torch.int\n"
9766
+ " %28 = torch.aten.gt.int %27, %int0 : !torch.int, !torch.int -> !torch.bool\n"
9767
+ " %29 = torch.aten.Int.bool %28 : !torch.bool -> !torch.int\n"
9768
+ " torch.prim.If.yield %29 : !torch.int\n"
9769
+ " }\n"
9770
+ " %13 = torch.aten.add.int %arg0, %3 : !torch.int, !torch.int -> !torch.int\n"
9771
+ " %14 = torch.prim.min.int %arg1, %13 : !torch.int, !torch.int -> !torch.int\n"
9772
+ " %15 = torch.prim.max.int %int0, %14 : !torch.int, !torch.int -> !torch.int\n"
9773
+ " %16 = torch.aten.add.int %arg0, %3 : !torch.int, !torch.int -> !torch.int\n"
9774
+ " %17 = torch.prim.min.int %arg0, %16 : !torch.int, !torch.int -> !torch.int\n"
9775
+ " %18 = torch.prim.max.int %int0, %17 : !torch.int, !torch.int -> !torch.int\n"
9776
+ " %19 = torch.aten.sub.int %15, %12 : !torch.int, !torch.int -> !torch.int\n"
9777
+ " %20 = torch.aten.add.int %19, %int1 : !torch.int, !torch.int -> !torch.int\n"
9778
+ " %21 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int\n"
9779
+ " %22 = torch.aten.mul.int %21, %20 : !torch.int, !torch.int -> !torch.int\n"
9780
+ " %23 = torch.aten.floordiv.int %22, %int2 : !torch.int, !torch.int -> !torch.int\n"
9781
+ " %24 = torch.aten.sub.int %18, %20 : !torch.int, !torch.int -> !torch.int\n"
9782
+ " %25 = torch.aten.mul.int %24, %arg1 : !torch.int, !torch.int -> !torch.int\n"
9783
+ " %26 = torch.prim.max.int %int0, %25 : !torch.int, !torch.int -> !torch.int\n"
9784
+ " torch.prim.If.yield %23, %26 : !torch.int, !torch.int\n"
9785
+ " }\n"
9786
+ " %7 = torch.aten.mul.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int\n"
9787
+ " %8 = torch.aten.add.int %6#0, %6#1 : !torch.int, !torch.int -> !torch.int\n"
9788
+ " %9 = torch.aten.sub.int %7, %8 : !torch.int, !torch.int -> !torch.int\n"
9789
+ " %10 = torch.prim.ListConstruct %int2, %9 : (!torch.int, !torch.int) -> !torch.list<int>\n"
9790
+ " torch.prim.If.yield %10 : !torch.list<int>\n"
9791
+ " }\n"
9792
+ " return %2 : !torch.list<int>\n"
9793
+ " }\n"
9732
9794
" func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
9733
9795
" %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
9734
9796
" return %0 : !torch.tuple<list<int>, list<int>>\n"
@@ -14023,6 +14085,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
14023
14085
" %int6 = torch.constant.int 6\n"
14024
14086
" return %int6 : !torch.int\n"
14025
14087
" }\n"
14088
+ " func.func @\"__torch_mlir_dtype_fn.aten.triu_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
14089
+ " %int4 = torch.constant.int 4\n"
14090
+ " %none = torch.constant.none\n"
14091
+ " %0 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
14092
+ " %1 = torch.prim.If %0 -> (!torch.int) {\n"
14093
+ " torch.prim.If.yield %int4 : !torch.int\n"
14094
+ " } else {\n"
14095
+ " %2 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
14096
+ " torch.prim.If.yield %2 : !torch.int\n"
14097
+ " }\n"
14098
+ " return %1 : !torch.int\n"
14099
+ " }\n"
14026
14100
" func.func @\"__torch_mlir_dtype_fn.aten.int_repr\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
14027
14101
" %int3 = torch.constant.int 3\n"
14028
14102
" %int1 = torch.constant.int 1\n"
0 commit comments