From fd0629ae788b37330dff5901e62d8c1b29d8f6c9 Mon Sep 17 00:00:00 2001
From: Praveen G
Date: Thu, 23 Jan 2025 14:47:19 +0000
Subject: [PATCH] Support batches and classes for nms lowering

---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    | 299 ++++++++++++------
 1 file changed, 207 insertions(+), 92 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 12d8683bc9d1..b160ee4b5ca7 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -9,6 +9,7 @@
 
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 
 using namespace mlir;
@@ -3703,30 +3704,11 @@
                 binder.op, "unimplemented: expected center_point_box "
                            "attribute value to be 0");
 
-        // TODO: Support multiple batches and classes
-        // Squeeze the boxes and scores tensor.
         // In Onnx, the shape of boxes is [BxNx4] while the
         // torchvision expects it to be of shape [Nx4]. Similarly, for
         // the scores tensor shape in Onnx is [BxCxN] while the
         // torchvision expects it to be of shape [N].
         Value boxes = operands[0], scores = operands[1];
-        FailureOr<Value> squeezedBoxes =
-            Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes);
-        if (failed(squeezedBoxes))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "failed to squeeze boxes tensor");
-        FailureOr<Value> squeezedScores =
-            Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores);
-        if (failed(squeezedScores))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "failed to squeeze scores tensor");
-        squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0,
-                                              squeezedScores.value());
-        if (failed(squeezedScores))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "failed to squeeze scores tensor");
-        boxes = squeezedBoxes.value();
-        scores = squeezedScores.value();
 
         // TODO: Support score_threshold input
         // Filter out the boxes if the score < score_threshold
@@ -3750,12 +3732,26 @@
                   binder.op,
                   "unimplemented: score_threshold should be <= min(scores)"));
         }
 
-        // Get max_output_boxes_per_class and iou_threshold
         Value cst0 = rewriter.create<Torch::ConstantIntOp>(
             loc, rewriter.getI64IntegerAttr(0));
         Value cst1 = rewriter.create<Torch::ConstantIntOp>(
             loc, rewriter.getI64IntegerAttr(1));
+        Value cst2 = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(2));
+        Value cst3 = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(3));
+        Value cst4 = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(4));
+
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
+        Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(
+            loc, rewriter.getBoolAttr(true));
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+            loc, rewriter.getBoolAttr(false));
+        Value maxOutputBoxesPerClass = cst0;
+
+        // Get max_output_boxes_per_class and iou_threshold
         Value iouThreshold = rewriter.create<Torch::ConstantFloatOp>(
             loc, rewriter.getF64FloatAttr(0.0));
         if (operands.size() > 3 &&
             !isa<Torch::NoneType>(operands[3].getType())) {
           iouThreshold = rewriter.create<Torch::AtenItemOp>(
               loc, rewriter.getType<Torch::FloatType>(), operands[3]);
         }
@@ -3769,87 +3765,206 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               loc, rewriter.getType<Torch::IntType>(), operands[2]);
         }
 
+        auto boxesTensorType = cast<Torch::ValueTensorType>(boxes.getType());
+        auto scoreTensorType = cast<Torch::ValueTensorType>(scores.getType());
+        auto boxSlicedType = rewriter.getType<Torch::ValueTensorType>(
+            boxesTensorType.getSizes().slice(1), boxesTensorType.getDtype());
+        auto scoreSlicedType = rewriter.getType<Torch::ValueTensorType>(
+            scoreTensorType.getSizes().slice(1), scoreTensorType.getDtype());
+
+        auto numBatches =
+            rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst0);
+        auto numClasses =
+            rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst1);
+        // auto numBoxes =
+        //     rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst2);
+
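+        // Overall approach: preallocate a [num_results, 3] output tensor and
+        // fill it one row at a time, running torchvision.nms once for every
+        // (batch, class) slice and writing [batch_index, class_index,
+        // box_index] triples into consecutive rows.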
+        std::optional<ArrayRef<int64_t>> resultShape =
+            cast<Torch::ValueTensorType>(resultType).getOptionalSizes();
+        if (!resultShape.has_value())
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected result tensor to have shape");
+
+        Value numResults = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(resultShape->front()));
+
+        auto intTy = rewriter.getType<Torch::IntType>();
+        auto intListTy = rewriter.getType<Torch::ListType>(intTy);
+
+        Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
+            loc, intListTy, SmallVector<Value>{numResults, cst3});
+
+        Value finalResult = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
+            loc, resultType, resultShapeList, /*dtype=*/cst4,
+            /*layout=*/cstNone,
+            /*device=*/cstNone, /*pinMemory=*/cstNone,
+            /*memoryFormat=*/cstNone);
+
         auto nmsTy = Torch::ValueTensorType::get(
             binder.op->getContext(), SmallVector<int64_t>{-1},
             rewriter.getIntegerType(64, /*signed=*/true));
-        Value result = rewriter.create<Torch::TorchvisionNmsOp>(
-            loc, nmsTy, boxes, scores, iouThreshold);
 
-        // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
-        Value numOutputBoxes =
-            rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
-        Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
-            loc, numOutputBoxes, maxOutputBoxesPerClass);
+        auto emptyTensorTy = rewriter.getType<Torch::ValueTensorType>(
+            SmallVector<int64_t>{}, nmsTy.getDtype());
 
-        auto nmsResultTy = Torch::ValueTensorType::get(
-            binder.op->getContext(),
-            SmallVector<int64_t>{resultType.getSizes()[0]},
-            rewriter.getIntegerType(64, /*signed=*/true));
-        auto ifSlice = rewriter.create<Torch::PrimIfOp>(
-            loc, TypeRange({nmsResultTy}), boxesCond);
+        auto nmsBatchLoop = rewriter.create<Torch::PrimLoopOp>(
+            loc, TypeRange({resultType, intTy}), numBatches, cstTrue,
+            ValueRange({finalResult, /*Index to finalResult*/ cst0}));
         {
+          // Batch loop body
           PatternRewriter::InsertionGuard guard(rewriter);
-          rewriter.createBlock(&ifSlice.getThenRegion(),
-                               ifSlice.getThenRegion().begin());
+          Block *batchLoopBody = rewriter.createBlock(
+              &nmsBatchLoop.getRegion(), nmsBatchLoop.getRegion().begin(),
+              TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
+          auto batchIV = batchLoopBody->getArgument(0);
+          auto currRes = batchLoopBody->getArgument(1);
+          auto finalResIdx = batchLoopBody->getArgument(2);
+
+          auto boxValue = rewriter.create<Torch::AtenSelectIntOp>(
+              loc, boxSlicedType, boxes, cst0, batchIV);
 
-          Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
-              loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
-              /*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
-          rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
+          auto batchValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+              loc, emptyTensorTy, batchIV);
+
+          auto nmsClassLoop = rewriter.create<Torch::PrimLoopOp>(
+              loc, TypeRange({resultType, intTy}), numClasses, cstTrue,
+              ValueRange({currRes, finalResIdx}));
+
+          {
+            // Class loop body
+            PatternRewriter::InsertionGuard guard(rewriter);
+            Block *classLoopBody = rewriter.createBlock(
+                &nmsClassLoop.getRegion(), nmsClassLoop.getRegion().begin(),
+                TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
+            auto classIV = classLoopBody->getArgument(0);
+            auto currRes = classLoopBody->getArgument(1);
+            auto finalResIdx = classLoopBody->getArgument(2);
+
+            auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
+                loc, scoreSlicedType, scores, cst0, batchIV);
+            auto scoreSelectType =
+                dyn_cast<Torch::ValueTensorType>(scoreSelect.getType());
+            auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
+                scoreSelectType.getSizes().slice(1),
+                scoreSelectType.getDtype());
+
+            auto scoreValue = rewriter.create<Torch::AtenSelectIntOp>(
+                loc, scoreValueType, scoreSelect, cst0, classIV);
+
+            auto classValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+                loc, emptyTensorTy, classIV);
+
+            Value result = rewriter.create<Torch::TorchvisionNmsOp>(
+                loc, nmsTy, boxValue, scoreValue, iouThreshold);
+
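+            // torchvision.nms returns the indices of the kept boxes sorted by
+            // decreasing score; clamp how many of them we emit to
+            // max_output_boxes_per_class.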
+            Value numOutputBoxes =
+                rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
+
+            numOutputBoxes = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+                loc, emptyTensorTy, numOutputBoxes);
+            Value maxBoxesPerClass =
+                rewriter.create<Torch::PrimNumToTensorScalarOp>(
+                    loc, emptyTensorTy, maxOutputBoxesPerClass);
+            auto minVal = rewriter.create<Torch::AtenMinimumOp>(
+                loc, numOutputBoxes.getType(), numOutputBoxes,
+                maxBoxesPerClass);
+            numOutputBoxes =
+                rewriter.create<Torch::AtenItemOp>(loc, intTy, minVal);
+
+            // Loop through the nms result
+            auto nmsLoop = rewriter.create<Torch::PrimLoopOp>(
+                loc, TypeRange({resultType, intTy}), numOutputBoxes, cstTrue,
+                ValueRange({currRes, finalResIdx}));
+            {
+              PatternRewriter::InsertionGuard guard(rewriter);
+              Block *loopBody = rewriter.createBlock(
+                  &nmsLoop.getRegion(), nmsLoop.getRegion().begin(),
+                  TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
+              auto iter = loopBody->getArgument(0);
+              auto currRes = loopBody->getArgument(1);
+              auto idxCst = loopBody->getArgument(2);
+
+              // Update batch
+
+              auto outputTensorSliceType =
+                  rewriter.getType<Torch::ValueTensorType>(
+                      SmallVector<int64_t>{3}, nmsTy.getDtype());
+
+              auto batchDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, currRes, cst0, idxCst);
+
+              auto batchSelect = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, emptyTensorTy, batchDim3D, cst0, cst0);
+
+              auto bCopy = rewriter.create<Torch::AtenCopyOp>(
+                  loc, batchSelect.getType(), batchSelect, batchValue,
+                  cstFalse);
+
+              batchDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, currRes, cst0, idxCst);
+
+              auto scatterBatch = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, outputTensorSliceType, batchDim3D, bCopy, cst0, cst0);
+              // Scatter the updated row back into the full result
+              auto batchResult = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, resultType, currRes, scatterBatch, cst0, idxCst);
+
+              // Class values
+              auto classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, batchResult, cst0, idxCst);
+              // Class indices
+              auto classSelect = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, emptyTensorTy, classDim3D, cst0, cst1);
+
+              auto cCopy = rewriter.create<Torch::AtenCopyOp>(
+                  loc, classSelect.getType(), classSelect, classValue,
+                  cstFalse);
+
+              classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, batchResult, cst0, idxCst);
+
+              auto scatterClass = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, outputTensorSliceType, classDim3D, cCopy, cst0, cst1);
+              auto classRes = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, resultType, batchResult, scatterClass, cst0, idxCst);
+
+              auto resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, classRes, cst0, idxCst);
+              // Box indices
+              auto resSelect = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, emptyTensorTy, resDim3D, cst0, cst2);
+
+              auto nmsResultValue = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, emptyTensorTy, result, cst0, iter);
+
+              auto rCopy = rewriter.create<Torch::AtenCopyOp>(
+                  loc, resSelect.getType(), resSelect, nmsResultValue,
+                  cstFalse);
+
+              resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, classRes, cst0, idxCst);
+
+              auto scatterRes = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, outputTensorSliceType, resDim3D, rCopy, cst0, cst2);
+              Value nmsResult = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, resultType, classRes, scatterRes, cst0, idxCst);
+
+              Value next =
+                  rewriter.create<Torch::AtenAddIntOp>(loc, idxCst, cst1);
+
+              rewriter.create<Torch::PrimLoopConditionOp>(
+                  loc, cstTrue, ValueRange({nmsResult, next}));
+            }
+            rewriter.create<Torch::PrimLoopConditionOp>(
+                loc, cstTrue,
+                ValueRange({nmsLoop.getResult(0), nmsLoop.getResult(1)}));
+          }
+          rewriter.create<Torch::PrimLoopConditionOp>(
+              loc, cstTrue,
+              ValueRange(
+                  {nmsClassLoop.getResult(0), nmsClassLoop.getResult(1)}));
         }
-        {
-          PatternRewriter::InsertionGuard guard(rewriter);
-          rewriter.createBlock(&ifSlice.getElseRegion(),
-                               ifSlice.getElseRegion().begin());
-
-          Value curResult = rewriter.create<Torch::TensorStaticInfoCastOp>(
-              loc, nmsResultTy, result);
-          rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
-        }
-        result = ifSlice.getResult(0);
-
-        // The result generated by torchvision.nms op is of shape [n], while the
-        // onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
-        // and make it of shape [n, 1] and then concatenate it with a zero
-        // tensor of shape [n, 2] to make it of shape [n, 3].
-        FailureOr<Value> unsqueezedResult =
-            Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
-        if (failed(unsqueezedResult))
-          return rewriter.notifyMatchFailure(
-              binder.op, "failed to unsqueeze result tensor");
-        result = unsqueezedResult.value();
-
-        numOutputBoxes =
-            rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
-        SmallVector<Value> zerosShapeValues{numOutputBoxes};
-        zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(2)));
-        Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
-            loc,
-            rewriter.getType<Torch::ListType>(
-                rewriter.getType<Torch::IntType>()),
-            zerosShapeValues);
-        std::optional<ArrayRef<int64_t>> resultShape =
-            cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
-        if (!resultShape.has_value())
-          return rewriter.notifyMatchFailure(
-              binder.op, "expected result tensor to have shape");
-        llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
-        auto zerosTy = Torch::ValueTensorType::get(
-            resultType.getContext(), zerosShape, resultType.getOptionalDtype());
-        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
-        Value zeros = rewriter.create<Torch::AtenZerosOp>(
-            loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone);
-
-        Type listElemType =
-            cast<Torch::BaseTensorType>(resultType)
-                .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
-                                      /*optionalDtype=*/nullptr);
-        Type listType = Torch::ListType::get(listElemType);
-        Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
-            loc, listType, SmallVector<Value>{zeros, result});
-        rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
-                                                      tensorList, cst1);
+        rewriter.replaceOp(binder.op, nmsBatchLoop.getResult(0));
         return success();
       });
 }