@@ -10024,10 +10024,65 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
+ " func.func @\"__torch_mlir_shape_fn.aten.conv2d.padding\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
+ " %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.str) -> !torch.list<int>\n"
+ " %1 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
+ " return %1 : !torch.list<int>\n"
+ " }\n"
+ " func.func @__torch__._conv_padding(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.str) -> !torch.list<int> {\n"
+ " %true = torch.constant.bool true\n"
+ " %int-1 = torch.constant.int -1\n"
+ " %str = torch.constant.str \"same\"\n"
+ " %none = torch.constant.none\n"
+ " %str_0 = torch.constant.str \"AssertionError: conv: weight must be at least 3 dimensional.\"\n"
+ " %int2 = torch.constant.int 2\n"
+ " %int0 = torch.constant.int 0\n"
+ " %int1 = torch.constant.int 1\n"
+ " %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+ " %1 = torch.aten.gt.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+ " torch.prim.If %1 -> () {\n"
+ " torch.prim.If.yield\n"
+ " } else {\n"
+ " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+ " torch.prim.If.yield\n"
+ " }\n"
+ " %2 = torch.aten.sub.int %0, %int2 : !torch.int, !torch.int -> !torch.int\n"
+ " %3 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>\n"
+ " %4 = torch.aten.mul.left_t %3, %2 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+ " %5 = torch.aten.eq.str %arg2, %str : !torch.str, !torch.str -> !torch.bool\n"
+ " torch.prim.If %5 -> () {\n"
+ " %6 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n"
+ " %7 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+ " %8 = torch.aten.__range_length %6, %int-1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+ " %9 = torch.prim.ListConstruct %7, %8 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+ " %10 = torch.prim.min.self_int %9 : !torch.list<int> -> !torch.int\n"
+ " torch.prim.Loop %10, %true, init() {\n"
+ " ^bb0(%arg3: !torch.int):\n"
+ " %11 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
+ " %12 = torch.aten.__derive_index %arg3, %6, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
+ " %13 = torch.aten.add.int %int2, %12 : !torch.int, !torch.int -> !torch.int\n"
+ " %14 = torch.aten.__getitem__.t %arg0, %13 : !torch.list<int>, !torch.int -> !torch.int\n"
+ " %15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int\n"
+ " %16 = torch.aten.mul.int %11, %15 : !torch.int, !torch.int -> !torch.int\n"
+ " %17 = torch.aten.floordiv.int %16, %int2 : !torch.int, !torch.int -> !torch.int\n"
+ " %18 = torch.aten._set_item.t %4, %12, %17 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
+ " torch.prim.Loop.condition %true, iter()\n"
+ " } : (!torch.int, !torch.bool) -> ()\n"
+ " torch.prim.If.yield\n"
+ " } else {\n"
+ " torch.prim.If.yield\n"
+ " }\n"
+ " return %4 : !torch.list<int>\n"
+ " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.conv3d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
+ " func.func @\"__torch_mlir_shape_fn.aten.conv3d.padding\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
+ " %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.str) -> !torch.list<int>\n"
+ " %1 = call @__torch__.torch.jit._shape_functions.conv3d(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
+ " return %1 : !torch.list<int>\n"
+ " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv_transpose2d.input\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = torch.derefine %arg3 : !torch.list<int> to !torch.optional<list<int>>\n"
" %1 = torch.derefine %arg4 : !torch.list<int> to !torch.optional<list<int>>\n"
@@ -10097,6 +10152,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %false, %0, %int1) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
+ " func.func @\"__torch_mlir_shape_fn.aten.conv1d.padding\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
+ " %false = torch.constant.bool false\n"
+ " %int1 = torch.constant.int 1\n"
+ " %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.str) -> !torch.list<int>\n"
+ " %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+ " %2 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %false, %1, %int1) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
+ " return %2 : !torch.list<int>\n"
+ " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv_transpose1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg7, %true, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
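For reference, the @__torch__._conv_padding helper added in the first hunk mirrors a small Python routine. The sketch below is a hand transcription of the MLIR body above, not the generator's actual Python source; the names conv_same_padding, weight_sizes, and num_spatial are illustrative only.

    from typing import List

    def conv_same_padding(weight_sizes: List[int], dilation: List[int], mode: str) -> List[int]:
        # Weight must carry at least (out_channels, in_channels, *spatial) dims.
        assert len(weight_sizes) > 2, "conv: weight must be at least 3 dimensional."
        num_spatial = len(weight_sizes) - 2
        padding = [0] * num_spatial  # default: no padding unless mode is "same"
        if mode == "same":
            # Walk spatial dims in reverse, pairing each with a dilation entry,
            # and pad by dilation * (kernel_size - 1) // 2 on each side.
            for d, j in zip(dilation, range(num_spatial - 1, -1, -1)):
                padding[j] = d * (weight_sizes[2 + j] - 1) // 2
        return padding

For example, conv_same_padding([16, 3, 3, 3], [1, 1], "same") yields [1, 1]; the added conv1d/conv2d/conv3d ".padding" shape functions compute this list first and then forward it to the existing conv shape routines in place of an integer padding list.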