Skip to content

Commit 27e8eb2

Browse files
zjgarvey authored and mgehre-amd committed
[ONNX] Fix resize ceil numerics and add half_pixel_symmetric support (llvm#3443)
This patch fixes several failing tests in our [external test suite](https://github.com/nod-ai/SHARK-TestSuite/tree/main/iree_tests/onnx/node/generated) and addresses some of the issues discussed in llvm#3420.
1 parent 2756231 commit 27e8eb2

File tree

2 files changed

+104
-2
lines changed

2 files changed

+104
-2
lines changed

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2648,14 +2648,21 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
26482648
nearestFP = b.create<arith::SelectOp>(loc, cmp, floor, ceil);
26492649
} else if (nearestMode == "round_prefer_ceil") {
26502650
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
2651+
Value cstOne = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1));
26512652
Value floor = b.create<math::FloorOp>(loc, proj);
26522653
Value ceil = b.create<math::CeilOp>(loc, proj);
26532654
Value decimal = b.create<arith::SubFOp>(loc, proj, floor);
26542655
Value cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE,
26552656
decimal, cstHalf);
26562657
nearestFP = b.create<arith::SelectOp>(loc, cmp, ceil, floor);
2658+
Value inputSizeMOne = b.create<arith::SubFOp>(loc, inputSizeFP, cstOne);
2659+
// don't extract out of bounds
2660+
nearestFP = b.create<arith::MinimumFOp>(loc, nearestFP, inputSizeMOne);
26572661
} else if (nearestMode == "ceil") {
2662+
Value cstOne = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1));
2663+
Value inputSizeMOne = b.create<arith::SubFOp>(loc, inputSizeFP, cstOne);
26582664
nearestFP = b.create<math::CeilOp>(loc, proj);
2665+
nearestFP = b.create<arith::MinimumFOp>(loc, nearestFP, inputSizeMOne);
26592666
} else {
26602667
llvm_unreachable("Unsupported nearest mode");
26612668
}
@@ -2729,7 +2736,8 @@ static Value BilinearInterpolate(OpBuilder &b,
27292736
if (coordStr == "_asymmetric") {
27302737
preClip = b.create<arith::DivFOp>(loc, outFP, scale);
27312738
}
2732-
if (coordStr == "_pytorch_half_pixel" || coordStr == "") {
2739+
if (coordStr == "_pytorch_half_pixel" || coordStr == "" ||
2740+
coordStr == "_half_pixel_symmetric") {
27332741
// half-pixel modes
27342742
// y_resized + 0.5
27352743
Value outPlusHalf = b.create<arith::AddFOp>(loc, outFP, cstHalf);
@@ -2738,6 +2746,18 @@ static Value BilinearInterpolate(OpBuilder &b,
27382746
// _ - 0.5
27392747
preClip = b.create<arith::SubFOp>(loc, outDivScale, cstHalf);
27402748
}
2749+
// for half_pixel_symmetric, need to compute offset from raw scales
2750+
if (coordStr == "_half_pixel_symmetric" && !scaleValues.empty()) {
2751+
Value outputSizeFromScale = b.create<arith::MulFOp>(loc, inputFP, scale);
2752+
Value adjustment =
2753+
b.create<arith::DivFOp>(loc, outputSizeFP, outputSizeFromScale);
2754+
Value cstTwo = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(2.0));
2755+
Value center = b.create<arith::DivFOp>(loc, inputFP, cstTwo);
2756+
Value oneMAdjustment =
2757+
b.create<arith::SubFOp>(loc, cstOneFloat, adjustment);
2758+
Value offset = b.create<arith::MulFOp>(loc, center, oneMAdjustment);
2759+
preClip = b.create<arith::AddFOp>(loc, offset, preClip);
2760+
}
27412761
// for pytorch half pixel , special case for length_resized == 1:
27422762
if (coordStr == "_pytorch_half_pixel") {
27432763
Value cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,

test/Conversion/TorchToLinalg/resize.mlir

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,89 @@ func.func @test_resize_nearest_3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1:
156156
return %7 : !torch.vtensor<[?,?,?,?,?],f32>
157157
}
158158

159-
// CHECK-LABEL: func.func @test_resize_nearest_half_pixel
159+
// -----
160+
161+
// CHECK-LABEL: func.func @test_resize_nearest_ceil
162+
func.func @test_resize_nearest_ceil(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> {
163+
// CHECK: %[[GENERIC:.*]] = linalg.generic
164+
// CHECK: %[[x11:.*]] = linalg.index 0 : index
165+
// CHECK: %[[x12:.*]] = linalg.index 1 : index
166+
// CHECK: %[[x13:.*]] = linalg.index 2 : index
167+
// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
168+
// CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
169+
// CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32
170+
// CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64
171+
// CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32
172+
// CHECK: %[[cst:.*]] = arith.constant 5.000000e-01 : f32
173+
// CHECK: %[[add:.*]] = arith.addf %[[x24]], %[[cst]] : f32
174+
// CHECK: %[[x25:.*]] = arith.divf %[[add]], %[[x21]] : f32
175+
// CHECK: %[[sub:.*]] = arith.subf %[[x25]], %[[cst]] : f32
176+
// CHECK: %[[cst3:.*]] = arith.constant 1.000000e+00 : f32
177+
// CHECK: %[[nM1:.*]] = arith.subf %[[inputsizefp:.*]], %[[cst3]]
178+
// CHECK: %[[ceil:.*]] = math.ceil %[[sub]] : f32
179+
// CHECK: %[[minindex:.*]] = arith.minimumf %[[ceil]], %[[nM1]]
180+
// CHECK: %[[x31:.*]] = arith.fptosi %[[minindex]] : f32 to i64
181+
// CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index
182+
// CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]]] : tensor<?x?x?xf32>
183+
// CHECK: linalg.yield %[[extracted]] : f32
184+
%none = torch.constant.none
185+
%none_0 = torch.constant.none
186+
%int0 = torch.constant.int 0
187+
%false = torch.constant.bool false
188+
%true = torch.constant.bool true
189+
%str = torch.constant.str "nearest_half_pixel,ceil"
190+
%int2 = torch.constant.int 2
191+
%0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
192+
%1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
193+
%4 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list<int>
194+
%5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?],f32>
195+
return %5 : !torch.vtensor<[?,?,?],f32>
196+
}
197+
198+
// -----
199+
200+
// CHECK-LABEL: func.func @test_resize_scales_linear_half_pixel_symmetric
201+
func.func @test_resize_scales_linear_half_pixel_symmetric(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4]
202+
,f64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
203+
// CHECK: %[[generic:.*]] = linalg.generic
204+
// CHECK: %[[cst7:.*]] = arith.constant 2.0
205+
// CHECK: %[[halfsize:.*]] = arith.divf %[[sizefp:.*]], %[[cst7]]
206+
// CHECK: %[[modifier:.*]] = arith.subf %[[cstOne:.*]], %[[adjustment:.*]]
207+
// CHECK: %[[offset:.*]] = arith.mulf %[[halfsize]], %[[modifier]]
208+
// CHECK: %[[preClip:.*]] = arith.addf %[[offset]], %[[halfpixelbase:.*]]
209+
// CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x1:.*]], %[[x2:.*]], %[[x3:.*]], %[[x4:.*]]] : tensor<1x1x2x4xf32>
210+
// CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
211+
// CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
212+
// CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
213+
// CHECK: %[[dx0p00:.*]] = arith.mulf %[[dx0:.*]], %[[extracted]]
214+
// CHECK: %[[dx1p01:.*]] = arith.mulf %[[dx1:.*]], %[[extracted_7]]
215+
// CHECK: %[[sum:.*]] = arith.addf %[[dx0p00]], %[[dx1p01]]
216+
// CHECK: %[[left:.*]] = arith.mulf %[[dy0:.*]], %[[sum]]
217+
// CHECK: %[[dx0p10:.*]] = arith.mulf %[[dx0]], %[[extracted_8]]
218+
// CHECK: %[[dx1p11:.*]] = arith.mulf %[[dx1]], %[[extracted_9]]
219+
// CHECK: %[[sum2:.*]] = arith.addf %[[dx0p10]], %[[dx1p11]]
220+
// CHECK: %[[right:.*]] = arith.mulf %[[dy1:.*]], %[[sum2]]
221+
// CHECK: %[[retval:.*]] = arith.addf %[[left]], %[[right]]
222+
%none = torch.constant.none
223+
%none_0 = torch.constant.none
224+
%int0 = torch.constant.int 0
225+
%false = torch.constant.bool false
226+
%true = torch.constant.bool true
227+
%str = torch.constant.str "bilinear_half_pixel_symmetric"
228+
%int2 = torch.constant.int 2
229+
%0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],f64>, !torch.int, !torch.int -> !torch.vtensor<[1],f64>
230+
%1 = torch.aten.item %0 : !torch.vtensor<[1],f64> -> !torch.float
231+
%int3 = torch.constant.int 3
232+
%2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],f64>, !torch.int, !torch.int -> !torch.vtensor<[1],f64>
233+
%3 = torch.aten.item %2 : !torch.vtensor<[1],f64> -> !torch.float
234+
%4 = torch.prim.ListConstruct %1, %3 : (!torch.float, !torch.float) -> !torch.list<float>
235+
%5 = torch.aten.__interpolate.size_list_scale_list %arg0, %none_0, %4, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.list<float>, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
236+
return %5 : !torch.vtensor<[?,?,?,?],f32>
237+
}
238+
239+
// -----
240+
241+
// CHECK-LABEL: func.func @test_resize_nearest_half_pixel_round_prefer_floor
160242
func.func @test_resize_nearest_half_pixel_round_prefer_floor(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> {
161243
// CHECK: %[[GENERIC:.*]] = linalg.generic
162244
// CHECK: %[[x11:.*]] = linalg.index 0 : index

0 commit comments

Comments
 (0)