diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 3db33aee1f1c..ad0b2b8cd500 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -3701,63 +3701,53 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           return rewriter.notifyMatchFailure(
               binder.op, "expected center_point_box attribute to be 0 or 1");
 
-        // TODO: Support multiple batches and classes
-        // Squeeze the boxes and scores tensor.
-        // In Onnx, the shape of boxes is [BxNx4] while the
-        // torchvision expects it to be of shape [Nx4]. Similarly, for
-        // the scores tensor shape in Onnx is [BxCxN] while the
-        // torchvision expects it to be of shape [N].
+        Value cst0 = rewriter.create<Torch::ConstantIntOp>(loc, 0);
+        Value cst1 = rewriter.create<Torch::ConstantIntOp>(loc, 1);
+        Value cst2 = rewriter.create<Torch::ConstantIntOp>(loc, 2);
+        Value cst3 = rewriter.create<Torch::ConstantIntOp>(loc, 3);
+        Value cst4 = rewriter.create<Torch::ConstantIntOp>(loc, 4);
+        Value cst2F = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(2.0));
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
+        Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(
+            loc, rewriter.getBoolAttr(true));
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+            loc, rewriter.getBoolAttr(false));
+
+        // In Onnx, the shape of boxes is [BxNx4] and that of scores is
+        // [BxCxN].
         Value boxes = operands[0], scores = operands[1];
-        FailureOr<Value> squeezedBoxes =
-            Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes);
-        if (failed(squeezedBoxes))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "failed to squeeze boxes tensor");
-        FailureOr<Value> squeezedScores =
-            Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores);
-        if (failed(squeezedScores))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "failed to squeeze scores tensor");
-        squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0,
-                                              squeezedScores.value());
-        if (failed(squeezedScores))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "failed to squeeze scores tensor");
-        boxes = squeezedBoxes.value();
-        scores = squeezedScores.value();
+
+        auto boxesTensorType = cast<Torch::ValueTensorType>(boxes.getType());
+        auto scoreTensorType = cast<Torch::ValueTensorType>(scores.getType());
+        auto boxSlicedType = rewriter.getType<Torch::ValueTensorType>(
+            boxesTensorType.getSizes().slice(1), boxesTensorType.getDtype());
+        auto scoreSlicedType = rewriter.getType<Torch::ValueTensorType>(
+            scoreTensorType.getSizes().slice(1), scoreTensorType.getDtype());
+
         if (centerPointBox == 1) {
           // When center_point_box is 1, the box data is supplied as
           // [[x_center, y_center, width, height], ...]. Slice it to
           // [[x_center, y_center], ...] and [[width, height], ...],
           // calculate the [[x1, y1], ...] and [[x2, y2], ...], and concatenate
           // to [[x1, y1, x2, y2], ...].
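+          // For example, a box [x_center=3, y_center=4, width=2, height=2]
+          // becomes [x1=2, y1=3, x2=4, y2=5].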
-          auto boxesTensorType =
-              dyn_cast<Torch::ValueTensorType>(boxes.getType());
-          Value const0 = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(0));
-          Value const1 = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(1));
-          Value const2 = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(2));
-          Value const4 = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getI64IntegerAttr(4));
-          Value const2F = rewriter.create<Torch::ConstantFloatOp>(
-              loc, rewriter.getF64FloatAttr(2.0));
           // extract scaled ranges for regions of interest
-          auto sliceShape = SmallVector<int64_t>{Torch::kUnknownSize, 2};
+          auto sliceShape =
+              SmallVector<int64_t>{Torch::kUnknownSize, Torch::kUnknownSize, 2};
           auto sliceTensorType = rewriter.getType<Torch::ValueTensorType>(
               sliceShape, boxesTensorType.getDtype());
+
+          // Boxes have shape [BxNx4].
           Value centers = rewriter.create<Torch::AtenSliceTensorOp>(
-              loc, sliceTensorType, boxes, const1, const0, const2, const1);
+              loc, sliceTensorType, boxes, cst2, cst0, cst2, cst1);
           Value sizes = rewriter.create<Torch::AtenSliceTensorOp>(
-              loc, sliceTensorType, boxes, const1, const2, const4, const1);
+              loc, sliceTensorType, boxes, cst2, cst2, cst4, cst1);
           Value halfSizes = rewriter.create<Torch::AtenDivScalarOp>(
-              loc, sizes.getType(), sizes, const2F);
+              loc, sizes.getType(), sizes, cst2F);
           Value x1y1s = rewriter.create<Torch::AtenSubTensorOp>(
-              loc, centers.getType(), centers, halfSizes, const1);
+              loc, centers.getType(), centers, halfSizes, cst1);
           Value x2y2s = rewriter.create<Torch::AtenAddTensorOp>(
-              loc, centers.getType(), centers, halfSizes, const1);
+              loc, centers.getType(), centers, halfSizes, cst1);
           Type listElemType = boxesTensorType.getWithSizesAndDtype(
               /*optionalSizes=*/std::nullopt,
@@ -3766,7 +3756,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
               loc, listType, SmallVector<Value>{x1y1s, x2y2s});
           boxes = rewriter.create<Torch::AtenCatOp>(loc, boxesTensorType,
-                                                    tensorList, const1);
+                                                    tensorList, cst2);
         }
 
         // TODO: Support score_threshold input
@@ -3792,10 +3782,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         }
 
         // Get max_output_boxes_per_class and iou_threshold
-        Value cst0 = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(0));
-        Value cst1 = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(1));
         Value maxOutputBoxesPerClass = cst0;
         Value iouThreshold = rewriter.create<Torch::ConstantFloatOp>(
             loc, rewriter.getF64FloatAttr(0.0));
@@ -3810,87 +3796,207 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               loc, rewriter.getType<Torch::IntType>(), operands[2]);
         }
 
+        // Since the shape of boxes is [BxNx4] in Onnx and torchvision
+        // expects it to be of shape [Nx4], loop over the batch dimension.
+        // Similarly, since the scores tensor has shape [BxCxN] in Onnx and
+        // torchvision expects it to be of shape [N], also loop over the
+        // class dimension.
+        auto numBatches =
+            rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst0);
+        auto numClasses =
+            rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst1);
+
+        // Create an empty tensor of shape (B*C*max_output_boxes_per_class, 3)
+        // to store the final result.
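+        // Each row holds one [batch_index, class_index, box_index] triplet.
+        // Note: the dtype constant cst4 passed to the empty op below selects
+        // torch's int64 dtype (torch ScalarType code 4 == Long).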
+        // We slice this down to the number of elements actually written at
+        // the end.
+
+        Value numResults = rewriter.create<Torch::AtenMulIntOp>(
+            loc, numClasses.getType(), numBatches, numClasses);
+        numResults = rewriter.create<Torch::AtenMulIntOp>(
+            loc, numClasses.getType(), numResults, maxOutputBoxesPerClass);
+
+        auto intTy = rewriter.getType<Torch::IntType>();
+        auto intListTy = rewriter.getType<Torch::ListType>(intTy);
+
+        Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
+            loc, intListTy, SmallVector<Value>{numResults, cst3});
+        Value finalResult = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
+            loc, resultType, resultShapeList, /*dtype=*/cst4,
+            /*layout=*/cstNone,
+            /*device=*/cstNone, /*pinMemory=*/cstNone,
+            /*memoryFormat=*/cstNone);
+
         auto nmsTy = Torch::ValueTensorType::get(
             binder.op->getContext(), SmallVector<int64_t>{-1},
             rewriter.getIntegerType(64, /*signed=*/true));
-        Value result = rewriter.create<Torch::TorchvisionNmsOp>(
-            loc, nmsTy, boxes, scores, iouThreshold);
-        // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
-        Value numOutputBoxes =
-            rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
-        Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
-            loc, numOutputBoxes, maxOutputBoxesPerClass);
+        auto emptyTensorTy = rewriter.getType<Torch::ValueTensorType>(
+            SmallVector<int64_t>{}, nmsTy.getDtype());
 
-        auto nmsResultTy = Torch::ValueTensorType::get(
-            binder.op->getContext(),
-            SmallVector<int64_t>{resultType.getSizes()[0]},
-            rewriter.getIntegerType(64, /*signed=*/true));
-        auto ifSlice = rewriter.create<Torch::PrimIfOp>(
-            loc, TypeRange({nmsResultTy}), boxesCond);
+        auto nmsBatchLoop = rewriter.create<Torch::PrimLoopOp>(
+            loc, TypeRange({resultType, intTy, intTy}), numBatches, cstTrue,
+            ValueRange({finalResult, /*Index to finalResult*/ cst0,
+                        /*Num values in result*/ cst0}));
         {
+          // Batch loop body
           PatternRewriter::InsertionGuard guard(rewriter);
-          rewriter.createBlock(&ifSlice.getThenRegion(),
-                               ifSlice.getThenRegion().begin());
+          Block *batchLoopBody = rewriter.createBlock(
+              &nmsBatchLoop.getRegion(), nmsBatchLoop.getRegion().begin(),
+              TypeRange({intTy, resultType, intTy, intTy}),
+              {loc, loc, loc, loc});
+
+          auto batchIV = batchLoopBody->getArgument(0);
+          auto currRes = batchLoopBody->getArgument(1);
+          auto finalResIdx = batchLoopBody->getArgument(2);
+          auto numResultValues = batchLoopBody->getArgument(3);
+
+          auto boxValue = rewriter.create<Torch::AtenSelectIntOp>(
+              loc, boxSlicedType, boxes, cst0, batchIV);
+          auto batchValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+              loc, emptyTensorTy, batchIV);
+
+          auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
+              loc, scoreSlicedType, scores, cst0, batchIV);
+          auto scoreSelectType =
+              cast<Torch::ValueTensorType>(scoreSelect.getType());
+          auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
+              scoreSelectType.getSizes().slice(1), scoreSelectType.getDtype());
+
+          auto nmsClassLoop = rewriter.create<Torch::PrimLoopOp>(
+              loc, TypeRange({resultType, intTy, intTy}), numClasses, cstTrue,
+              ValueRange({currRes, finalResIdx, numResultValues}));
 
-          Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
-              loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
-              /*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
-          rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
-        }
-        {
-          PatternRewriter::InsertionGuard guard(rewriter);
-          rewriter.createBlock(&ifSlice.getElseRegion(),
-                               ifSlice.getElseRegion().begin());
-
-          Value curResult = rewriter.create<Torch::TensorStaticInfoCastOp>(
-              loc, nmsResultTy, result);
-          rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
-        }
-        result = ifSlice.getResult(0);
-
-        // The result generated by torchvision.nms op is of shape [n], while the
-        // onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
-        // and make it of shape [n, 1] and then concatenate it with a zero
-        // tensor of shape [n, 2] to make it of shape [n, 3].
-        FailureOr<Value> unsqueezedResult =
-            Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
-        if (failed(unsqueezedResult))
-          return rewriter.notifyMatchFailure(
-              binder.op, "failed to unsqueeze result tensor");
-        result = unsqueezedResult.value();
-
-        numOutputBoxes =
-            rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
-        SmallVector<Value> zerosShapeValues{numOutputBoxes};
-        zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(2)));
-        Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
-            loc,
-            rewriter.getType<Torch::ListType>(
-                rewriter.getType<Torch::IntType>()),
-            zerosShapeValues);
-        std::optional<ArrayRef<int64_t>> resultShape =
-            cast<Torch::BaseTensorType>(result.getType()).getOptionalSizes();
-        if (!resultShape.has_value())
-          return rewriter.notifyMatchFailure(
-              binder.op, "expected result tensor to have shape");
-        llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
-        auto zerosTy = Torch::ValueTensorType::get(
-            resultType.getContext(), zerosShape, resultType.getOptionalDtype());
-        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
-        Value zeros = rewriter.create<Torch::AtenZerosOp>(
-            loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone);
-
-        Type listElemType =
-            cast<Torch::BaseTensorType>(resultType)
-                .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
-                                      /*optionalDtype=*/nullptr);
-        Type listType = Torch::ListType::get(listElemType);
-        Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
-            loc, listType, SmallVector<Value>{zeros, result});
-        rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
-                                                      tensorList, cst1);
+          {
+            // Class loop body
+            PatternRewriter::InsertionGuard guard(rewriter);
+            Block *classLoopBody = rewriter.createBlock(
+                &nmsClassLoop.getRegion(), nmsClassLoop.getRegion().begin(),
+                TypeRange({intTy, resultType, intTy, intTy}),
+                {loc, loc, loc, loc});
+
+            auto classIV = classLoopBody->getArgument(0);
+            auto currRes = classLoopBody->getArgument(1);
+            auto finalResIdx = classLoopBody->getArgument(2);
+            Value numResultValues = classLoopBody->getArgument(3);
+
+            auto scoreValue = rewriter.create<Torch::AtenSelectIntOp>(
+                loc, scoreValueType, scoreSelect, cst0, classIV);
+            auto classValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+                loc, emptyTensorTy, classIV);
+
+            // TorchVision NMS
+            Value result = rewriter.create<Torch::TorchvisionNmsOp>(
+                loc, nmsTy, boxValue, scoreValue, iouThreshold);
+
+            // Compute the number of output boxes for this (batch, class)
+            // pair: min(num selected, max_output_boxes_per_class).
+            Value numOutputBoxes =
+                rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
+            numOutputBoxes = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+                loc, emptyTensorTy, numOutputBoxes);
+            Value maxBoxesPerClass =
+                rewriter.create<Torch::PrimNumToTensorScalarOp>(
+                    loc, emptyTensorTy, maxOutputBoxesPerClass);
+            auto minVal = rewriter.create<Torch::AtenMinimumOp>(
+                loc, emptyTensorTy, numOutputBoxes, maxBoxesPerClass);
+            numOutputBoxes =
+                rewriter.create<Torch::AtenItemOp>(loc, intTy, minVal);
+
+            // Loop through the nms result.
+            // The result of the torchvision nms op has shape [num_selected],
+            // while onnx expects [num_selected, 3], where each selected entry
+            // is in the format [batch_index, class_index, box_index]. Insert
+            // the triplet [batch_index, class_index, box_index] into
+            // `finalResult` element by element for each box.
+
+            // TODO: This could be simplified by concatenating the result of
+            // nms with tensors filled with the batch and class indices
+            // instead of using the loop below. Currently that approach
+            // results in failures while lowering due to dynamic dims.
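+            // Each iteration below writes one [batch_index, class_index,
+            // box_index] row of `currRes` via a select / copy /
+            // select_scatter sequence per column, then advances the running
+            // write index.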
+
+            auto nmsLoop = rewriter.create<Torch::PrimLoopOp>(
+                loc, TypeRange({resultType, intTy}), numOutputBoxes, cstTrue,
+                ValueRange({currRes, finalResIdx}));
+            {
+              PatternRewriter::InsertionGuard guard(rewriter);
+              Block *loopBody = rewriter.createBlock(
+                  &nmsLoop.getRegion(), nmsLoop.getRegion().begin(),
+                  TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
+              auto iter = loopBody->getArgument(0);
+              auto currRes = loopBody->getArgument(1);
+              auto idxCst = loopBody->getArgument(2);
+
+              auto outputTensorSliceType =
+                  rewriter.getType<Torch::ValueTensorType>(
+                      SmallVector<int64_t>{3}, nmsTy.getDtype());
+
+              // Update batch dimension
+              auto batchDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, currRes, cst0, idxCst);
+              auto batchSelect = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, emptyTensorTy, batchDim3D, cst0, cst0);
+              auto bCopy = rewriter.create<Torch::AtenCopyOp>(
+                  loc, batchSelect.getType(), batchSelect, batchValue,
+                  cstFalse);
+              batchDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, currRes, cst0, idxCst);
+              auto scatterBatch = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, outputTensorSliceType, batchDim3D, bCopy, cst0, cst0);
+              auto batchResult = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, resultType, currRes, scatterBatch, cst0, idxCst);
+
+              // Update class dimension
+              auto classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, batchResult, cst0, idxCst);
+              auto classSelect = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, emptyTensorTy, classDim3D, cst0, cst1);
+              auto cCopy = rewriter.create<Torch::AtenCopyOp>(
+                  loc, classSelect.getType(), classSelect, classValue,
+                  cstFalse);
+              classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, batchResult, cst0, idxCst);
+              auto scatterClass = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, outputTensorSliceType, classDim3D, cCopy, cst0, cst1);
+              auto classRes = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, resultType, batchResult, scatterClass, cst0, idxCst);
+
+              // Update nms result dimension
+              auto resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, classRes, cst0, idxCst);
+              auto resSelect = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, emptyTensorTy, resDim3D, cst0, cst2);
+              auto nmsResultValue = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, emptyTensorTy, result, cst0, iter);
+              auto rCopy = rewriter.create<Torch::AtenCopyOp>(
+                  loc, resSelect.getType(), resSelect, nmsResultValue,
+                  cstFalse);
+              resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, classRes, cst0, idxCst);
+              auto scatterRes = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, outputTensorSliceType, resDim3D, rCopy, cst0, cst2);
+              Value nmsResult = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, resultType, classRes, scatterRes, cst0, idxCst);
+
+              // Increment the result index
+              Value next =
+                  rewriter.create<Torch::AtenAddIntOp>(loc, idxCst, cst1);
+              rewriter.create<Torch::PrimLoopConditionOp>(
+                  loc, cstTrue, ValueRange({nmsResult, next}));
+            }
+            // Update the number of result values written so far.
+            numResultValues = rewriter.create<Torch::AtenAddIntOp>(
+                loc, numResultValues, numOutputBoxes);
+            rewriter.create<Torch::PrimLoopConditionOp>(
+                loc, cstTrue,
+                ValueRange({nmsLoop.getResult(0), nmsLoop.getResult(1),
+                            numResultValues}));
+          }
+          rewriter.create<Torch::PrimLoopConditionOp>(
+              loc, cstTrue,
+              ValueRange({nmsClassLoop.getResult(0), nmsClassLoop.getResult(1),
+                          nmsClassLoop->getResult(2)}));
+        }
+        // Slice the result to the required number of elements.
+        rewriter.replaceOpWithNewOp<Torch::AtenSliceTensorOp>(
+            binder.op, resultType, nmsBatchLoop.getResult(0), cst0, cst0,
+            nmsBatchLoop.getResult(2), cst1);
         return success();
       });
 }
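For reference, the loop nest emitted above corresponds to the following plain C++ sketch of the same packing scheme (hypothetical code for illustration only: nmsPerClass stands in for the per-(batch, class) torchvision-style NMS call and is not part of this patch):

#include <algorithm>
#include <array>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical stand-in for torchvision.nms applied to one (batch, class)
// slice; returns the indices of the boxes kept for that slice.
using NmsFn = std::function<std::vector<int64_t>(int64_t, int64_t)>;

// Mirrors the lowering: allocate B * C * maxOutputBoxesPerClass rows up
// front, append one [batch_index, class_index, box_index] triplet per kept
// box, and finally slice the buffer down to the rows actually written.
std::vector<std::array<int64_t, 3>>
packSelectedIndices(int64_t numBatches, int64_t numClasses,
                    int64_t maxOutputBoxesPerClass, const NmsFn &nmsPerClass) {
  std::vector<std::array<int64_t, 3>> result(
      numBatches * numClasses * maxOutputBoxesPerClass);
  int64_t writeIdx = 0; // "Index to finalResult" in the lowering
  for (int64_t b = 0; b < numBatches; ++b) {   // torch.prim.Loop (batch)
    for (int64_t c = 0; c < numClasses; ++c) { // torch.prim.Loop (class)
      std::vector<int64_t> kept = nmsPerClass(b, c);
      // min(numOutputBoxes, max_output_boxes_per_class), as in the lowering.
      int64_t numOutputBoxes = std::min<int64_t>(
          static_cast<int64_t>(kept.size()), maxOutputBoxesPerClass);
      for (int64_t i = 0; i < numOutputBoxes; ++i) // torch.prim.Loop (nms)
        result[writeIdx++] = {b, c, kept[i]};
    }
  }
  result.resize(writeIdx); // the final torch.aten.slice.Tensor
  return result;
}

The preallocate-then-slice shape of the lowering follows from the number of kept boxes per (batch, class) pair only being known at runtime; the buffer is sized for the worst case and trimmed at the end.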
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index b2c718bceace..e4ec736c2a8e 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -2034,53 +2034,25 @@ func.func @test_loop_forlike(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtens
 // CHECK-SAME:    %[[VAL_3:.*]]: !torch.vtensor<[1],f32>,
 // CHECK-SAME:    %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
 func.func @test_nonmaxsuppression_identical_boxes(%arg0: !torch.vtensor<[1,10,4],f32>, %arg1: !torch.vtensor<[1,1,10],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[VAL_5:.*]] = torch.constant.int 0
-  // CHECK: %[[VAL_6:.*]] = torch.constant.int 1
-  // CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,10,4],f32>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK: %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,10,4],f32>, !torch.int -> !torch.vtensor<[10,4],f32>
-  // CHECK: %[[VAL_10:.*]] = torch.constant.int 0
-  // CHECK: %[[VAL_11:.*]] = torch.constant.int 1
-  // CHECK: %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,10],f32>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK: %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,10],f32>, !torch.int -> !torch.vtensor<[1,10],f32>
-  // CHECK: %[[VAL_15:.*]] = torch.constant.int 0
-  // CHECK: %[[VAL_16:.*]] = torch.constant.int 1
-  // CHECK: %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,10],f32>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,10],f32>, !torch.int -> !torch.vtensor<[10],f32>
-  // CHECK: %[[VAL_20:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
-  // CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[],f32>
-  // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float
-  // CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool
-  // CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)"
-  // CHECK: %[[VAL_24:.*]] = torch.constant.int 0
-  // CHECK: %[[VAL_25:.*]] = torch.constant.int 1
-  // CHECK: %[[VAL_26:.*]] = torch.constant.float 0.000000e+00
-  // CHECK: %[[VAL_27:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float
-  // CHECK: %[[VAL_28:.*]] = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: %[[VAL_29:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_27]] : !torch.vtensor<[10,4],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[?],si64>
-  // CHECK: %[[VAL_30:.*]] = torch.aten.size.int %[[VAL_29]], %[[VAL_24]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_31:.*]] = torch.aten.gt.int %[[VAL_30]], %[[VAL_28]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: %[[VAL_32:.*]] = torch.prim.If %[[VAL_31]] -> (!torch.vtensor<[1],si64>)
-  // CHECK:   %[[SLICE:.*]] = torch.aten.slice.Tensor %[[VAL_29]], %[[VAL_24]], %[[VAL_24]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK:   torch.prim.If.yield %[[SLICE]] : !torch.vtensor<[1],si64>
-  // CHECK: } else {
-  // CHECK:   %[[CAST:.*]] = torch.tensor_static_info_cast %[[VAL_29]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64>
-  // CHECK:   torch.prim.If.yield %[[CAST]] : !torch.vtensor<[1],si64>
-  // CHECK: }
-  // CHECK: %[[VAL_33:.*]] = torch.aten.unsqueeze %[[VAL_32]], %[[VAL_25]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
-  // CHECK: %[[VAL_34:.*]] = torch.aten.size.int %[[VAL_33]], %[[VAL_24]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_35:.*]] = torch.constant.int 2
-  // CHECK: %[[VAL_36:.*]] = torch.prim.ListConstruct %[[VAL_34]], %[[VAL_35]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[VAL_37:.*]] = torch.constant.none
-  // CHECK: %[[VAL_38:.*]] = torch.aten.zeros %[[VAL_36]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
-  // CHECK: %[[VAL_39:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_33]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
-  // CHECK: %[[VAL_40:.*]] = torch.aten.cat %[[VAL_39]], %[[VAL_25]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
-  // CHECK: return %[[VAL_40]] : !torch.vtensor<[1,3],si64>
+// CHECK: %[[RES:.*]] = torch.aten.empty.memory_format
+// CHECK: %[[LOOP_BATCH:.*]]:3 = torch.prim.Loop
+// CHECK: %[[LOOP_CLASS:.*]]:3 = torch.prim.Loop
+// CHECK: %[[NMS:.*]] = torch.torchvision.nms
+// CHECK: %[[MIN_RES:.*]] = torch.aten.minimum
+// CHECK: %[[NUM_ELE:.*]] = torch.aten.item %[[MIN_RES]] : !torch.vtensor<[],si64> -> !torch.int
+// CHECK: %[[LOOP_NMS:.*]]:2 = torch.prim.Loop %[[NUM_ELE]]
+// CHECK: %[[SEL_1:.*]] = torch.aten.select.int
+// CHECK: %[[SEL_2:.*]] = torch.aten.select.int
+// CHECK: %[[COPY:.*]] = torch.aten.copy
+// CHECK: %[[SEL_3:.*]] = torch.aten.select.int
+// CHECK-COUNT-6: torch.aten.select_scatter
+// CHECK: %[[ADD_INDEX:.*]] = torch.aten.add.int
+// CHECK: torch.prim.Loop.condition
+// CHECK: %[[ADD_INDEX_1:.*]] = torch.aten.add.int
+// CHECK: torch.prim.Loop.condition
+// CHECK: torch.prim.Loop.condition
+// CHECK: %[[SLICE_RES:.*]] = torch.aten.slice.Tensor
+// CHECK: return %[[SLICE_RES]] : !torch.vtensor<[1,3],si64>
   %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,10,4],f32>, !torch.vtensor<[1,1,10],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
   return %0 : !torch.vtensor<[1,3],si64>
 }
@@ -2094,53 +2066,25 @@ func.func @test_nonmaxsuppression_identical_boxes(%arg0: !torch.vtensor<[1,10,4]
 // CHECK-SAME:    %[[VAL_3:.*]]: !torch.vtensor<[1],f32>,
 // CHECK-SAME:    %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""}
 func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>, %arg1: !torch.vtensor<[1,1,1],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[VAL_5:.*]] = torch.constant.int 0
-  // CHECK: %[[VAL_6:.*]] = torch.constant.int 1
-  // CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK: %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>
-  // CHECK: %[[VAL_10:.*]] = torch.constant.int 0
-  // CHECK: %[[VAL_11:.*]] = torch.constant.int 1
-  // CHECK: %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK: %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
-  // CHECK: %[[VAL_15:.*]] = torch.constant.int 0
-  // CHECK: %[[VAL_16:.*]] = torch.constant.int 1
-  // CHECK: %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1],f32>
-  // CHECK: %[[VAL_20:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
-  // CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<[],f32>
-  // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float
-  // CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool
-  // CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)"
-  // CHECK: %[[VAL_24:.*]] = torch.constant.int 0
-  // CHECK: %[[VAL_25:.*]] = torch.constant.int 1
-  // CHECK: %[[VAL_26:.*]] = torch.constant.float 0.000000e+00
-  // CHECK: %[[VAL_27:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float
-  // CHECK: %[[VAL_28:.*]] = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: %[[VAL_29:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_27]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[?],si64>
-  // CHECK: %[[VAL_30:.*]] = torch.aten.size.int %[[VAL_29]], %[[VAL_24]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_31:.*]] = torch.aten.gt.int %[[VAL_30]], %[[VAL_28]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: %[[VAL_32:.*]] = torch.prim.If %[[VAL_31]] -> (!torch.vtensor<[1],si64>)
-  // CHECK:   %[[SLICE:.*]] = torch.aten.slice.Tensor %[[VAL_29]], %[[VAL_24]], %[[VAL_24]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK:   torch.prim.If.yield %[[SLICE]] : !torch.vtensor<[1],si64>
-  // CHECK: } else {
-  // CHECK:   %[[CAST:.*]] = torch.tensor_static_info_cast %[[VAL_29]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64>
-  // CHECK:   torch.prim.If.yield %[[CAST]] : !torch.vtensor<[1],si64>
-  // CHECK: }
-  // CHECK: %[[VAL_33:.*]] = torch.aten.unsqueeze %[[VAL_32]], %[[VAL_25]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
-  // CHECK: %[[VAL_34:.*]] = torch.aten.size.int %[[VAL_33]], %[[VAL_24]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_35:.*]] = torch.constant.int 2
-  // CHECK: %[[VAL_36:.*]] = torch.prim.ListConstruct %[[VAL_34]], %[[VAL_35]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[VAL_37:.*]] = torch.constant.none
-  // CHECK: %[[VAL_38:.*]] = torch.aten.zeros %[[VAL_36]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
-  // CHECK: %[[VAL_39:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_33]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
-  // CHECK: %[[VAL_40:.*]] = torch.aten.cat %[[VAL_39]], %[[VAL_25]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
-  // CHECK: return %[[VAL_40]] : !torch.vtensor<[1,3],si64>
+// CHECK: %[[RES:.*]] = torch.aten.empty.memory_format
+// CHECK: %[[LOOP_BATCH:.*]]:3 = torch.prim.Loop
+// CHECK: %[[LOOP_CLASS:.*]]:3 = torch.prim.Loop
+// CHECK: %[[NMS:.*]] = torch.torchvision.nms
+// CHECK: %[[MIN_RES:.*]] = torch.aten.minimum
+// CHECK: %[[NUM_ELE:.*]] = torch.aten.item %[[MIN_RES]] : !torch.vtensor<[],si64> -> !torch.int
+// CHECK: %[[LOOP_NMS:.*]]:2 = torch.prim.Loop %[[NUM_ELE]]
+// CHECK: %[[SEL_1:.*]] = torch.aten.select.int
+// CHECK: %[[SEL_2:.*]] = torch.aten.select.int
+// CHECK: %[[COPY:.*]] = torch.aten.copy
+// CHECK: %[[SEL_3:.*]] = torch.aten.select.int
+// CHECK-COUNT-6: torch.aten.select_scatter
+// CHECK: %[[ADD_INDEX:.*]] = torch.aten.add.int
+// CHECK: torch.prim.Loop.condition
+// CHECK: %[[ADD_INDEX_1:.*]] = torch.aten.add.int
+// CHECK: torch.prim.Loop.condition
+// CHECK: torch.prim.Loop.condition
+// CHECK: %[[SLICE_RES:.*]] = torch.aten.slice.Tensor
+// CHECK: return %[[SLICE_RES]] : !torch.vtensor<[1,3],si64>
   %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
   return %0 : !torch.vtensor<[1,3],si64>
 }
@@ -2152,68 +2096,82 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>,
 // CHECK-SAME:    %[[VAL_3:.*]]: !torch.vtensor<[1],f32>,
 // CHECK-SAME:    %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
 func.func @test_nonmaxsuppression_center_point_box(%arg0: !torch.vtensor<[1,1,4],f32>, %arg1: !torch.vtensor<[1,1,1],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  // CHECK: %[[VAL_5:.*]] = torch.constant.int 0
-  // CHECK: %[[VAL_6:.*]] = torch.constant.int 1
-  // CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK: %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>
-  // CHECK: %[[VAL_10:.*]] = torch.constant.int 0
-  // CHECK: %[[VAL_11:.*]] = torch.constant.int 1
-  // CHECK: %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK: %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
-  // CHECK: %[[VAL_15:.*]] = torch.constant.int 0
-  // CHECK: %[[VAL_16:.*]] = torch.constant.int 1
-  // CHECK: %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1."
-  // CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1],f32>
-  // CHECK: %[[VAL_20:.*]] = torch.constant.int 0
-  // CHECK: %[[VAL_21:.*]] = torch.constant.int 1
-  // CHECK: %[[VAL_22:.*]] = torch.constant.int 2
-  // CHECK: %[[VAL_23:.*]] = torch.constant.int 4
-  // CHECK: %[[VAL_24:.*]] = torch.constant.float 2.000000e+00
-  // CHECK: %[[VAL_25:.*]] = torch.aten.slice.Tensor %[[VAL_9]], %[[VAL_21]], %[[VAL_20]], %[[VAL_22]], %[[VAL_21]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,2],f32>
-  // CHECK: %[[VAL_26:.*]] = torch.aten.slice.Tensor %[[VAL_9]], %[[VAL_21]], %[[VAL_22]], %[[VAL_23]], %[[VAL_21]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,2],f32>
-  // CHECK: %[[VAL_27:.*]] = torch.aten.div.Scalar %[[VAL_26]], %[[VAL_24]] : !torch.vtensor<[?,2],f32>, !torch.float -> !torch.vtensor<[?,2],f32>
-  // CHECK: %[[VAL_28:.*]] = torch.aten.sub.Tensor %[[VAL_25]], %[[VAL_27]], %[[VAL_21]] : !torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>, !torch.int -> !torch.vtensor<[?,2],f32>
-  // CHECK: %[[VAL_29:.*]] = torch.aten.add.Tensor %[[VAL_25]], %[[VAL_27]], %[[VAL_21]] : !torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>, !torch.int -> !torch.vtensor<[?,2],f32>
-  // CHECK: %[[VAL_30:.*]] = torch.prim.ListConstruct %[[VAL_28]], %[[VAL_29]] : (!torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>) -> !torch.list<vtensor>
-  // CHECK: %[[VAL_31:.*]] = torch.aten.cat %[[VAL_30]], %[[VAL_21]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,4],f32>
-  // CHECK: %[[VAL_32:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
-  // CHECK: %[[VAL_33:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<[],f32>
-  // CHECK: %[[VAL_34:.*]] = torch.aten.item %[[VAL_33]] : !torch.vtensor<[],f32> -> !torch.float
-  // CHECK: %[[VAL_35:.*]] = torch.aten.ge.float %[[VAL_34]], %[[VAL_32]] : !torch.float, !torch.float -> !torch.bool
-  // CHECK: torch.runtime.assert %[[VAL_35]], "unimplemented: score_threshold should be <= min(scores)"
-  // CHECK: %[[VAL_36:.*]] = torch.constant.int 0
-  // CHECK: %[[VAL_37:.*]] = torch.constant.int 1
-  // CHECK: %[[VAL_38:.*]] = torch.constant.float 0.000000e+00
-  // CHECK: %[[VAL_39:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
-  // CHECK: %[[VAL_40:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
-  // CHECK: %[[VAL_41:.*]] = torch.torchvision.nms %[[VAL_31]], %[[VAL_19]], %[[VAL_39]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[?],si64>
-  // CHECK: %[[VAL_42:.*]] = torch.aten.size.int %[[VAL_41]], %[[VAL_36]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_43:.*]] = torch.aten.gt.int %[[VAL_42]], %[[VAL_40]] : !torch.int, !torch.int -> !torch.bool
-  // CHECK: %[[VAL_44:.*]] = torch.prim.If %[[VAL_43]] -> (!torch.vtensor<[1],si64>) {
-  // CHECK:   %[[VAL_45:.*]] = torch.aten.slice.Tensor %[[VAL_41]], %[[VAL_36]], %[[VAL_36]], %[[VAL_40]], %[[VAL_37]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
-  // CHECK:   torch.prim.If.yield %[[VAL_45]] : !torch.vtensor<[1],si64>
-  // CHECK: } else {
-  // CHECK:   %[[VAL_46:.*]] = torch.tensor_static_info_cast %[[VAL_41]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64>
-  // CHECK:   torch.prim.If.yield %[[VAL_46]] : !torch.vtensor<[1],si64>
-  // CHECK: }
-  // CHECK: %[[VAL_47:.*]] = torch.aten.unsqueeze %[[VAL_44]], %[[VAL_37]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
-  // CHECK: %[[VAL_48:.*]] = torch.aten.size.int %[[VAL_47]], %[[VAL_36]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
-  // CHECK: %[[VAL_49:.*]] = torch.constant.int 2
-  // CHECK: %[[VAL_50:.*]] = torch.prim.ListConstruct %[[VAL_48]], %[[VAL_49]] : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[VAL_51:.*]] = torch.constant.none
-  // CHECK: %[[VAL_52:.*]] = torch.aten.zeros %[[VAL_50]], %[[VAL_51]], %[[VAL_51]], %[[VAL_51]], %[[VAL_51]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
-  // CHECK: %[[VAL_53:.*]] = torch.prim.ListConstruct %[[VAL_52]], %[[VAL_47]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
-  // CHECK: %[[VAL_54:.*]] = torch.aten.cat %[[VAL_53]], %[[VAL_37]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
-  // CHECK: return %[[VAL_54]] : !torch.vtensor<[1,3],si64>
+// CHECK: %[[VAL_5:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_6:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_7:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_8:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_9:.*]] = torch.constant.int 4
+// CHECK: %[[VAL_10:.*]] = torch.constant.float 2.000000e+00
+// CHECK: %[[VAL_11:.*]] = torch.constant.none
+// CHECK: %[[VAL_12:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_13:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_14:.*]] = torch.aten.slice.Tensor %[[VAL_0]], %[[VAL_7]], %[[VAL_5]], %[[VAL_7]], %[[VAL_6]] : !torch.vtensor<[1,1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,2],f32>
+// CHECK: %[[VAL_15:.*]] = torch.aten.slice.Tensor %[[VAL_0]], %[[VAL_7]], %[[VAL_7]], %[[VAL_9]], %[[VAL_6]] : !torch.vtensor<[1,1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,2],f32>
+// CHECK: %[[VAL_16:.*]] = torch.aten.div.Scalar %[[VAL_15]], %[[VAL_10]] : !torch.vtensor<[?,?,2],f32>, !torch.float -> !torch.vtensor<[?,?,2],f32>
+// CHECK: %[[VAL_17:.*]] = torch.aten.sub.Tensor %[[VAL_14]], %[[VAL_16]], %[[VAL_6]] : !torch.vtensor<[?,?,2],f32>, !torch.vtensor<[?,?,2],f32>, !torch.int -> !torch.vtensor<[?,?,2],f32>
+// CHECK: %[[VAL_18:.*]] = torch.aten.add.Tensor %[[VAL_14]], %[[VAL_16]], %[[VAL_6]] : !torch.vtensor<[?,?,2],f32>, !torch.vtensor<[?,?,2],f32>, !torch.int -> !torch.vtensor<[?,?,2],f32>
+// CHECK: %[[VAL_19:.*]] = torch.prim.ListConstruct %[[VAL_17]], %[[VAL_18]] : (!torch.vtensor<[?,?,2],f32>, !torch.vtensor<[?,?,2],f32>) -> !torch.list<vtensor>
+// CHECK: %[[VAL_20:.*]] = torch.aten.cat %[[VAL_19]], %[[VAL_7]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,1,4],f32>
+// CHECK: %[[VAL_21:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
+// CHECK: %[[VAL_22:.*]] = torch.aten.min %[[VAL_1]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[],f32>
+// CHECK: %[[VAL_23:.*]] = torch.aten.item %[[VAL_22]] : !torch.vtensor<[],f32> -> !torch.float
+// CHECK: %[[VAL_24:.*]] = torch.aten.ge.float %[[VAL_23]], %[[VAL_21]] : !torch.float, !torch.float -> !torch.bool
+// CHECK: torch.runtime.assert %[[VAL_24]], "unimplemented: score_threshold should be <= min(scores)"
+// CHECK: %[[RES:.*]] = torch.aten.empty.memory_format
+// CHECK: %[[LOOP_BATCH:.*]]:3 = torch.prim.Loop
+// CHECK: %[[LOOP_CLASS:.*]]:3 = torch.prim.Loop
+// CHECK: %[[NMS:.*]] = torch.torchvision.nms
+// CHECK: %[[MIN_RES:.*]] = torch.aten.minimum
+// CHECK: %[[NUM_ELE:.*]] = torch.aten.item %[[MIN_RES]] : !torch.vtensor<[],si64> -> !torch.int
+// CHECK: %[[LOOP_NMS:.*]]:2 = torch.prim.Loop %[[NUM_ELE]]
+// CHECK: %[[SEL_1:.*]] = torch.aten.select.int
+// CHECK: %[[SEL_2:.*]] = torch.aten.select.int
+// CHECK: %[[COPY:.*]] = torch.aten.copy
+// CHECK: %[[SEL_3:.*]] = torch.aten.select.int
+// CHECK-COUNT-6: torch.aten.select_scatter
+// CHECK: %[[ADD_INDEX:.*]] = torch.aten.add.int
+// CHECK: torch.prim.Loop.condition
+// CHECK: %[[ADD_INDEX_1:.*]] = torch.aten.add.int
+// CHECK: torch.prim.Loop.condition
+// CHECK: torch.prim.Loop.condition
+// CHECK: %[[SLICE_RES:.*]] = torch.aten.slice.Tensor
+// CHECK: return %[[SLICE_RES]] : !torch.vtensor<[1,3],si64>
   %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.center_point_box = 1 : si64} : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
   return %0 : !torch.vtensor<[1,3],si64>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @test_nonmaxsuppression_multiple_batch_class(
+// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[3,8,4],f32>,
+// CHECK-SAME:    %[[VAL_1:.*]]: !torch.vtensor<[3,5,8],f32>,
+// CHECK-SAME:    %[[VAL_2:.*]]: !torch.vtensor<[1],si64>,
+// CHECK-SAME:    %[[VAL_3:.*]]: !torch.vtensor<[1],f32>,
+// CHECK-SAME:    %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[?,3],si64> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
+func.func @test_nonmaxsuppression_multiple_batch_class(%arg0: !torch.vtensor<[3,8,4],f32>, %arg1: !torch.vtensor<[3,5,8],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[?,3],si64> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
+// CHECK: %[[RES:.*]] = torch.aten.empty.memory_format
+// CHECK: %[[LOOP_BATCH:.*]]:3 = torch.prim.Loop
+// CHECK: %[[LOOP_CLASS:.*]]:3 = torch.prim.Loop
+// CHECK: %[[NMS:.*]] = torch.torchvision.nms
+// CHECK: %[[MIN_RES:.*]] = torch.aten.minimum
+// CHECK: %[[NUM_ELE:.*]] = torch.aten.item %[[MIN_RES]] : !torch.vtensor<[],si64> -> !torch.int
+// CHECK: %[[LOOP_NMS:.*]]:2 = torch.prim.Loop %[[NUM_ELE]]
+// CHECK: %[[SEL_1:.*]] = torch.aten.select.int
+// CHECK: %[[SEL_2:.*]] = torch.aten.select.int
+// CHECK: %[[COPY:.*]] = torch.aten.copy
+// CHECK: %[[SEL_3:.*]] = torch.aten.select.int
+// CHECK-COUNT-6: torch.aten.select_scatter
+// CHECK: %[[ADD_INDEX:.*]] = torch.aten.add.int
+// CHECK: torch.prim.Loop.condition
+// CHECK: %[[ADD_INDEX_1:.*]] = torch.aten.add.int
+// CHECK: torch.prim.Loop.condition
+// CHECK: torch.prim.Loop.condition
+// CHECK: %[[SLICE_RES:.*]] = torch.aten.slice.Tensor
+// CHECK: return %[[SLICE_RES]] : !torch.vtensor<[?,3],si64>
+  %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.center_point_box = 0 : si64} : (!torch.vtensor<[3,8,4],f32>, !torch.vtensor<[3,5,8],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[?,3],si64>
+  return %0 : !torch.vtensor<[?,3],si64>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @test_mwm