@@ -261,15 +261,16 @@ func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %ar
 // CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices
 func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[AXIS:.*]] = torch.constant.int 1
-  // CHECK: %[[ZERO:.+]] = torch.constant.int 0
-  // CHECK: %[[ONE:.+]] = torch.constant.int 1
-  // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
-  // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
-  // CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
-  // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
-  // CHECK: %[[STR:.*]] = torch.constant.str "add"
-  // CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
+  // CHECK: %[[AXIS:.*]] = torch.constant.int 1
+  // CHECK: %[[ZERO:.*]] = torch.constant.int 0
+  // CHECK: %[[FIVE:.*]] = torch.constant.int 1
+  // CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int
+  // CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64>
+  // CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1>
+  // CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64>
+  // CHECK: %[[STR:.*]] = torch.constant.str "sum"
+  // CHECK: %[[TRUE:.*]] = torch.constant.bool true
+  // CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32>
   %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
   return %0 : !torch.vtensor<[1,5],f32>
 }
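
Aside (not part of the diff): the updated CHECK lines expect `onnx.ScatterElements` with `reduction = "add"` to lower to `torch.aten.scatter_reduce.two` using the reduce string `"sum"` and `include_self = true`. A minimal PyTorch sketch of that semantics, with illustrative values (assumed, modeled on the duplicate-indices case, not taken from the diff):

```python
import torch

# Illustrative values (assumed): duplicate indices along axis 1, so both
# updates accumulate into the same column of the data tensor.
data = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
indices = torch.tensor([[1, 1]])
updates = torch.tensor([[1.1, 2.1]])

# include_self=True keeps the original element in the reduction, which matches
# the ONNX ScatterElements reduction="add" behavior the lowering targets.
out = data.scatter_reduce(1, indices, updates, reduce="sum", include_self=True)
print(out)  # tensor([[1.0000, 5.2000, 3.0000, 4.0000, 5.0000]]) -> 2.0 + 1.1 + 2.1 = 5.2
```

The `.two` variant carries an explicit `include_self` operand; passing `true` matches ONNX ScatterElements, where the original tensor value participates in the reduction alongside all updates that target it.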
@@ -294,15 +295,16 @@ func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>,
 // CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul
 func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[AXIS:.*]] = torch.constant.int 1
-  // CHECK: %[[ZERO:.+]] = torch.constant.int 0
-  // CHECK: %[[ONE:.+]] = torch.constant.int 1
-  // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]]
-  // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]]
-  // CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]]
-  // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1
-  // CHECK: %[[STR:.*]] = torch.constant.str "multiply"
-  // CHECK: torch.aten.scatter.reduce %arg0, %[[AXIS]], %[[WHERE]], %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32>
+  // CHECK: %[[AXIS:.*]] = torch.constant.int 1
+  // CHECK: %[[ZERO:.*]] = torch.constant.int 0
+  // CHECK: %[[FIVE:.*]] = torch.constant.int 1
+  // CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int
+  // CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64>
+  // CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1>
+  // CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64>
+  // CHECK: %[[STR:.*]] = torch.constant.str "prod"
+  // CHECK: %[[TRUE:.*]] = torch.constant.bool true
+  // CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32>
   %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32>
   return %0 : !torch.vtensor<[1,5],f32>
 }
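
Aside (not part of the diff): for `reduction = "mul"` the expectation changes analogously, from `torch.aten.scatter.reduce` with `"multiply"` to `torch.aten.scatter_reduce.two` with `"prod"` and `include_self = true`. A minimal sketch with assumed values:

```python
import torch

# Illustrative values (assumed): both updates multiply into column 1,
# together with the original element already stored there.
data = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
indices = torch.tensor([[1, 1]])
updates = torch.tensor([[2.0, 3.0]])

out = data.scatter_reduce(1, indices, updates, reduce="prod", include_self=True)
print(out)  # tensor([[ 1., 12.,  3.,  4.,  5.]]) -> 2.0 * 2.0 * 3.0 = 12.0
```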