Skip to content

Commit f4840ed

Browse files
authored
[ONNX] Fix onnx.ScatterElements with AtenScatterReduceTwoOp lowering to tm_tensor/linalg_ext dialect (#3754)
- Fixes the onnx.ScatterElements issue: nod-ai/SHARK-ModelDev#823
- E2E test: nod-ai/SHARK-TestSuite#363
1 parent 53f7532 commit f4840ed

File tree

3 files changed

+27
-23
lines changed

3 files changed

+27
-23
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -635,18 +635,21 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
635635

636636
// TODO: Implement max and min cases
637637
if (reduction == "mul") {
638-
reduction = "multiply";
638+
reduction = "prod";
639639
} else if (reduction == "max" || reduction == "min") {
640640
return rewriter.notifyMatchFailure(
641641
binder.op, "max/min reduction unsupported for scatter elements");
642+
} else if (reduction == "add") {
643+
reduction = "sum";
642644
}
643645

644646
Value cstStrReduction =
645647
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), reduction);
646-
647-
rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceOp>(
648+
Value cstTrue =
649+
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
650+
rewriter.replaceOpWithNewOp<Torch::AtenScatterReduceTwoOp>(
648651
binder.op, resultType, data, constAxis, indices, updates,
649-
cstStrReduction);
652+
cstStrReduction, cstTrue);
650653
return success();
651654
});
652655
patterns.onOp(

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3084,7 +3084,6 @@
30843084
"ScatterReduceIntMaxModuleIncludeSelf",
30853085
"ScatterReduceIntMinModuleIncludeSelf",
30863086
"ScatterValueFloatModule_basic",
3087-
"ScatterAddStaticModule_basic",
30883087
# Failure - onnx_lowering: onnx.ScatterND
30893088
"IndexPut1DFloatAccumulateModule_basic",
30903089
"IndexPut1DIntAccumulateModule_basic",

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -261,15 +261,16 @@ func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %ar
261261

262262
// CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices
263263
func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
264-
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
265-
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
266-
// CHECK: %[[ONE:.+]] = torch.constant.int 1
267-
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
268-
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
269-
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
270-
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
271-
// CHECK: %[[STR:.*]] = torch.constant.str "add"
272-
// CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
264+
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
265+
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
266+
// CHECK: %[[FIVE:.*]] = torch.constant.int 1
267+
// CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int
268+
// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64>
269+
// CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1>
270+
// CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64>
271+
// CHECK: %[[STR:.*]] = torch.constant.str "sum"
272+
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
273+
// CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32>
273274
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
274275
return %0 : !torch.vtensor<[1,5],f32>
275276
}
@@ -294,15 +295,16 @@ func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>,
294295

295296
// CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul
296297
func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
297-
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
298-
// CHECK: %[[ZERO:.+]] = torch.constant.int 0
299-
// CHECK: %[[ONE:.+]] = torch.constant.int 1
300-
// CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
301-
// CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
302-
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
303-
// CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
304-
// CHECK: %[[STR:.*]] = torch.constant.str "multiply"
305-
// CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
298+
// CHECK: %[[AXIS:.*]] = torch.constant.int 1
299+
// CHECK: %[[ZERO:.*]] = torch.constant.int 0
300+
// CHECK: %[[FIVE:.*]] = torch.constant.int 1
301+
// CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int
302+
// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64>
303+
// CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1>
304+
// CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64>
305+
// CHECK: %[[STR:.*]] = torch.constant.str "prod"
306+
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
307+
// CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32>
306308
%0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
307309
return %0 : !torch.vtensor<[1,5],f32>
308310
}

0 commit comments

Comments (0)