@@ -156,7 +156,89 @@ func.func @test_resize_nearest_3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1:
156
156
return %7 : !torch.vtensor <[?,?,?,?,?],f32 >
157
157
}
158
158
159
- // CHECK-LABEL: func.func @test_resize_nearest_half_pixel
159
+ // -----
160
+
161
+ // CHECK-LABEL: func.func @test_resize_nearest_ceil
162
+ func.func @test_resize_nearest_ceil (%arg0: !torch.vtensor <[?,?,?],f32 >, %arg1: !torch.vtensor <[3 ],si64 >) -> !torch.vtensor <[?,?,?],f32 > {
163
+ // CHECK: %[[GENERIC:.*]] = linalg.generic
164
+ // CHECK: %[[x11:.*]] = linalg.index 0 : index
165
+ // CHECK: %[[x12:.*]] = linalg.index 1 : index
166
+ // CHECK: %[[x13:.*]] = linalg.index 2 : index
167
+ // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
168
+ // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
169
+ // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32
170
+ // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64
171
+ // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32
172
+ // CHECK: %[[cst:.*]] = arith.constant 5.000000e-01 : f32
173
+ // CHECK: %[[add:.*]] = arith.addf %[[x24]], %[[cst]] : f32
174
+ // CHECK: %[[x25:.*]] = arith.divf %[[add]], %[[x21]] : f32
175
+ // CHECK: %[[sub:.*]] = arith.subf %[[x25]], %[[cst]] : f32
176
+ // CHECK: %[[cst3:.*]] = arith.constant 1.000000e+00 : f32
177
+ // CHECK: %[[nM1:.*]] = arith.subf %[[inputsizefp:.*]], %[[cst3]]
178
+ // CHECK: %[[ceil:.*]] = math.ceil %[[sub]] : f32
179
+ // CHECK: %[[minindex:.*]] = arith.minimumf %[[ceil]], %[[nM1]]
180
+ // CHECK: %[[x31:.*]] = arith.fptosi %[[minindex]] : f32 to i64
181
+ // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index
182
+ // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]]] : tensor<?x?x?xf32>
183
+ // CHECK: linalg.yield %[[extracted]] : f32
184
+ %none = torch.constant.none
185
+ %none_0 = torch.constant.none
186
+ %int0 = torch.constant.int 0
187
+ %false = torch.constant.bool false
188
+ %true = torch.constant.bool true
189
+ %str = torch.constant.str " nearest_half_pixel,ceil"
190
+ %int2 = torch.constant.int 2
191
+ %0 = torch.aten.select.int %arg1 , %int0 , %int2 : !torch.vtensor <[3 ],si64 >, !torch.int , !torch.int -> !torch.vtensor <[1 ],si64 >
192
+ %1 = torch.aten.item %0 : !torch.vtensor <[1 ],si64 > -> !torch.int
193
+ %4 = torch.prim.ListConstruct %1 : (!torch.int ) -> !torch.list <int >
194
+ %5 = torch.aten.__interpolate.size_list_scale_list %arg0 , %4 , %none_0 , %str , %false , %none_0 , %false : !torch.vtensor <[?,?,?],f32 >, !torch.list <int >, !torch.none , !torch.str , !torch.bool , !torch.none , !torch.bool -> !torch.vtensor <[?,?,?],f32 >
195
+ return %5 : !torch.vtensor <[?,?,?],f32 >
196
+ }
197
+
198
+ // -----
199
+
200
+ // CHECK-LABEL: func.func @test_resize_scales_linear_half_pixel_symmetric
201
+ func.func @test_resize_scales_linear_half_pixel_symmetric (%arg0: !torch.vtensor <[1 ,1 ,2 ,4 ],f32 >, %arg1: !torch.vtensor <[4 ]
202
+ ,f64 >) -> !torch.vtensor <[?,?,?,?],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 19 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
203
+ // CHECK: %[[generic:.*]] = linalg.generic
204
+ // CHECK: %[[cst7:.*]] = arith.constant 2.0
205
+ // CHECK: %[[halfsize:.*]] = arith.divf %[[sizefp:.*]], %[[cst7]]
206
+ // CHECK: %[[modifier:.*]] = arith.subf %[[cstOne:.*]], %[[adjustment:.*]]
207
+ // CHECK: %[[offset:.*]] = arith.mulf %[[halfsize]], %[[modifier]]
208
+ // CHECK: %[[preClip:.*]] = arith.addf %[[offset]], %[[halfpixelbase:.*]]
209
+ // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x1:.*]], %[[x2:.*]], %[[x3:.*]], %[[x4:.*]]] : tensor<1x1x2x4xf32>
210
+ // CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
211
+ // CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
212
+ // CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
213
+ // CHECK: %[[dx0p00:.*]] = arith.mulf %[[dx0:.*]], %[[extracted]]
214
+ // CHECK: %[[dx1p01:.*]] = arith.mulf %[[dx1:.*]], %[[extracted_7]]
215
+ // CHECK: %[[sum:.*]] = arith.addf %[[dx0p00]], %[[dx1p01]]
216
+ // CHECK: %[[left:.*]] = arith.mulf %[[dy0:.*]], %[[sum]]
217
+ // CHECK: %[[dx0p10:.*]] = arith.mulf %[[dx0]], %[[extracted_8]]
218
+ // CHECK: %[[dx1p11:.*]] = arith.mulf %[[dx1]], %[[extracted_9]]
219
+ // CHECK: %[[sum2:.*]] = arith.addf %[[dx0p10]], %[[dx1p11]]
220
+ // CHECK: %[[right:.*]] = arith.mulf %[[dy1:.*]], %[[sum2]]
221
+ // CHECK: %[[retval:.*]] = arith.addf %[[left]], %[[right]]
222
+ %none = torch.constant.none
223
+ %none_0 = torch.constant.none
224
+ %int0 = torch.constant.int 0
225
+ %false = torch.constant.bool false
226
+ %true = torch.constant.bool true
227
+ %str = torch.constant.str " bilinear_half_pixel_symmetric"
228
+ %int2 = torch.constant.int 2
229
+ %0 = torch.aten.select.int %arg1 , %int0 , %int2 : !torch.vtensor <[4 ],f64 >, !torch.int , !torch.int -> !torch.vtensor <[1 ],f64 >
230
+ %1 = torch.aten.item %0 : !torch.vtensor <[1 ],f64 > -> !torch.float
231
+ %int3 = torch.constant.int 3
232
+ %2 = torch.aten.select.int %arg1 , %int0 , %int3 : !torch.vtensor <[4 ],f64 >, !torch.int , !torch.int -> !torch.vtensor <[1 ],f64 >
233
+ %3 = torch.aten.item %2 : !torch.vtensor <[1 ],f64 > -> !torch.float
234
+ %4 = torch.prim.ListConstruct %1 , %3 : (!torch.float , !torch.float ) -> !torch.list <float >
235
+ %5 = torch.aten.__interpolate.size_list_scale_list %arg0 , %none_0 , %4 , %str , %false , %none_0 , %false : !torch.vtensor <[1 ,1 ,2 ,4 ],f32 >, !torch.none , !torch.list <float >, !torch.str , !torch.bool , !torch.none , !torch.bool -> !torch.vtensor <[?,?,?,?],f32 >
236
+ return %5 : !torch.vtensor <[?,?,?,?],f32 >
237
+ }
238
+
239
+ // -----
240
+
241
+ // CHECK-LABEL: func.func @test_resize_nearest_half_pixel_round_prefer_floor
160
242
func.func @test_resize_nearest_half_pixel_round_prefer_floor (%arg0: !torch.vtensor <[?,?,?],f32 >, %arg1: !torch.vtensor <[3 ],si64 >) -> !torch.vtensor <[?,?,?],f32 > {
161
243
// CHECK: %[[GENERIC:.*]] = linalg.generic
162
244
// CHECK: %[[x11:.*]] = linalg.index 0 : index
0 commit comments