Skip to content

Commit fd0629a

Browse files
committed
Support batches and classes for nms lowering
1 parent 481da8d commit fd0629a

File tree

1 file changed

+207
-92
lines changed

1 file changed

+207
-92
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 207 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
1111
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
12+
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
1213
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
1314

1415
using namespace mlir;
@@ -3703,30 +3704,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
37033704
binder.op, "unimplemented: expected center_point_box "
37043705
"attribute value to be 0");
37053706

3706-
// TODO: Support multiple batches and classes
3707-
// Squeeze the boxes and scores tensor.
37083707
// In Onnx, the shape of boxes is [BxNx4] while the
37093708
// torchvision expects it to be of shape [Nx4]. Similarly,
37103709
// the scores tensor shape in Onnx is [BxCxN] while the
37113710
// torchvision expects it to be of shape [N].
37123711
Value boxes = operands[0], scores = operands[1];
3713-
FailureOr<Value> squeezedBoxes =
3714-
Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes);
3715-
if (failed(squeezedBoxes))
3716-
return rewriter.notifyMatchFailure(binder.op,
3717-
"failed to squeeze boxes tensor");
3718-
FailureOr<Value> squeezedScores =
3719-
Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores);
3720-
if (failed(squeezedScores))
3721-
return rewriter.notifyMatchFailure(binder.op,
3722-
"failed to squeeze scores tensor");
3723-
squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0,
3724-
squeezedScores.value());
3725-
if (failed(squeezedScores))
3726-
return rewriter.notifyMatchFailure(binder.op,
3727-
"failed to squeeze scores tensor");
3728-
boxes = squeezedBoxes.value();
3729-
scores = squeezedScores.value();
37303712

37313713
// TODO: Support score_threshold input
37323714
// Filter out the boxes if the score < score_threshold
@@ -3750,12 +3732,26 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
37503732
"unimplemented: score_threshold should be <= min(scores)"));
37513733
}
37523734

3753-
// Get max_output_boxes_per_class and iou_threshold
37543735
Value cst0 = rewriter.create<Torch::ConstantIntOp>(
37553736
loc, rewriter.getI64IntegerAttr(0));
37563737
Value cst1 = rewriter.create<Torch::ConstantIntOp>(
37573738
loc, rewriter.getI64IntegerAttr(1));
3739+
Value cst2 = rewriter.create<Torch::ConstantIntOp>(
3740+
loc, rewriter.getI64IntegerAttr(2));
3741+
Value cst3 = rewriter.create<Torch::ConstantIntOp>(
3742+
loc, rewriter.getI64IntegerAttr(3));
3743+
Value cst4 = rewriter.create<Torch::ConstantIntOp>(
3744+
loc, rewriter.getI64IntegerAttr(4));
3745+
3746+
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
3747+
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(
3748+
loc, rewriter.getBoolAttr(true));
3749+
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
3750+
loc, rewriter.getBoolAttr(false));
3751+
37583752
Value maxOutputBoxesPerClass = cst0;
3753+
3754+
// Get max_output_boxes_per_class and iou_threshold
37593755
Value iouThreshold = rewriter.create<Torch::ConstantFloatOp>(
37603756
loc, rewriter.getF64FloatAttr(0.0));
37613757
if (operands.size() > 3 &&
@@ -3769,87 +3765,206 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
37693765
loc, rewriter.getType<Torch::IntType>(), operands[2]);
37703766
}
37713767

3768+
auto boxesTensorType = cast<Torch::ValueTensorType>(boxes.getType());
3769+
auto scoreTensorType = cast<Torch::ValueTensorType>(scores.getType());
3770+
auto boxSlicedType = rewriter.getType<Torch::ValueTensorType>(
3771+
boxesTensorType.getSizes().slice(1), boxesTensorType.getDtype());
3772+
auto scoreSlicedType = rewriter.getType<Torch::ValueTensorType>(
3773+
scoreTensorType.getSizes().slice(1), scoreTensorType.getDtype());
3774+
3775+
auto numBatches =
3776+
rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst0);
3777+
auto numClasses =
3778+
rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst1);
3779+
// auto numBoxes =
3780+
// rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst2);
3781+
3782+
std::optional<ArrayRef<int64_t>> resultShape =
3783+
cast<Torch::ValueTensorType>(resultType).getOptionalSizes();
3784+
if (!resultShape.has_value())
3785+
return rewriter.notifyMatchFailure(
3786+
binder.op, "Expected result tensor to have shape");
3787+
3788+
Value numResults = rewriter.create<Torch::ConstantIntOp>(
3789+
loc, rewriter.getI64IntegerAttr(resultShape->front()));
3790+
3791+
auto intTy = rewriter.getType<Torch::IntType>();
3792+
auto intListTy = rewriter.getType<Torch::ListType>(intTy);
3793+
3794+
Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
3795+
loc, intListTy, SmallVector<Value>{numResults, cst3});
3796+
3797+
Value finalResult = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
3798+
loc, resultType, resultShapeList, /*dtype=*/cst4,
3799+
/*layout=*/cstNone,
3800+
/*device=*/cstNone, /*pinMemory=*/cstNone,
3801+
/*memoryFormat=*/cstNone);
3802+
37723803
auto nmsTy = Torch::ValueTensorType::get(
37733804
binder.op->getContext(), SmallVector<int64_t>{-1},
37743805
rewriter.getIntegerType(64, /*signed=*/true));
3775-
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
3776-
loc, nmsTy, boxes, scores, iouThreshold);
37773806

3778-
// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
3779-
Value numOutputBoxes =
3780-
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
3781-
Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
3782-
loc, numOutputBoxes, maxOutputBoxesPerClass);
3807+
auto emptyTensorTy = rewriter.getType<Torch::ValueTensorType>(
3808+
SmallVector<int64_t>{}, nmsTy.getDtype());
37833809

3784-
auto nmsResultTy = Torch::ValueTensorType::get(
3785-
binder.op->getContext(),
3786-
SmallVector<int64_t>{resultType.getSizes()[0]},
3787-
rewriter.getIntegerType(64, /*signed=*/true));
3788-
auto ifSlice = rewriter.create<Torch::PrimIfOp>(
3789-
loc, TypeRange({nmsResultTy}), boxesCond);
3810+
auto nmsBatchLoop = rewriter.create<Torch::PrimLoopOp>(
3811+
loc, TypeRange({resultType, intTy}), numBatches, cstTrue,
3812+
ValueRange({finalResult, /*Index to finalResult*/ cst0}));
37903813
{
3814+
// Batch loop body
37913815
PatternRewriter::InsertionGuard guard(rewriter);
3792-
rewriter.createBlock(&ifSlice.getThenRegion(),
3793-
ifSlice.getThenRegion().begin());
3816+
Block *batchLoopBody = rewriter.createBlock(
3817+
&nmsBatchLoop.getRegion(), nmsBatchLoop.getRegion().begin(),
3818+
TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
3819+
auto batchIV = batchLoopBody->getArgument(0);
3820+
auto currRes = batchLoopBody->getArgument(1);
3821+
auto finalResIdx = batchLoopBody->getArgument(2);
3822+
3823+
auto boxValue = rewriter.create<Torch::AtenSelectIntOp>(
3824+
loc, boxSlicedType, boxes, cst0, batchIV);
37943825

3795-
Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
3796-
loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
3797-
/*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
3798-
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
3826+
auto batchValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
3827+
loc, emptyTensorTy, batchIV);
3828+
3829+
auto nmsClassLoop = rewriter.create<Torch::PrimLoopOp>(
3830+
loc, TypeRange({resultType, intTy}), numClasses, cstTrue,
3831+
ValueRange({currRes, finalResIdx}));
3832+
3833+
{
3834+
// Class loop body
3835+
PatternRewriter::InsertionGuard guard(rewriter);
3836+
Block *classLoopBody = rewriter.createBlock(
3837+
&nmsClassLoop.getRegion(), nmsClassLoop.getRegion().begin(),
3838+
TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
3839+
auto classIV = classLoopBody->getArgument(0);
3840+
auto currRes = classLoopBody->getArgument(1);
3841+
auto finalResIdx = classLoopBody->getArgument(2);
3842+
3843+
auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
3844+
loc, scoreSlicedType, scores, cst0, batchIV);
3845+
auto scoreSelectType =
3846+
dyn_cast<Torch::ValueTensorType>(scoreSelect.getType());
3847+
auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
3848+
scoreSelectType.getSizes().slice(1),
3849+
scoreSelectType.getDtype());
3850+
3851+
auto scoreValue = rewriter.create<Torch::AtenSelectIntOp>(
3852+
loc, scoreValueType, scoreSelect, cst0, classIV);
3853+
3854+
auto classValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
3855+
loc, emptyTensorTy, classIV);
3856+
3857+
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
3858+
loc, nmsTy, boxValue, scoreValue, iouThreshold);
3859+
3860+
Value numOutputBoxes =
3861+
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
3862+
3863+
numOutputBoxes = rewriter.create<Torch::PrimNumToTensorScalarOp>(
3864+
loc, emptyTensorTy, numOutputBoxes);
3865+
Value maxBoxesPerClass =
3866+
rewriter.create<Torch::PrimNumToTensorScalarOp>(
3867+
loc, emptyTensorTy, maxOutputBoxesPerClass);
3868+
auto minVal = rewriter.create<Torch::AtenMinimumOp>(
3869+
loc, numOutputBoxes.getType(), numOutputBoxes,
3870+
maxBoxesPerClass);
3871+
numOutputBoxes =
3872+
rewriter.create<Torch::AtenItemOp>(loc, intTy, minVal);
3873+
3874+
// Loop through the nms result
3875+
auto nmsLoop = rewriter.create<Torch::PrimLoopOp>(
3876+
loc, TypeRange({resultType, intTy}), numOutputBoxes, cstTrue,
3877+
ValueRange({currRes, finalResIdx}));
3878+
{
3879+
PatternRewriter::InsertionGuard guard(rewriter);
3880+
Block *loopBody = rewriter.createBlock(
3881+
&nmsLoop.getRegion(), nmsLoop.getRegion().begin(),
3882+
TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
3883+
auto iter = loopBody->getArgument(0);
3884+
auto currRes = loopBody->getArgument(1);
3885+
auto idxCst = loopBody->getArgument(2);
3886+
3887+
// Update batch
3888+
3889+
auto outputTensorSliceType =
3890+
rewriter.getType<Torch::ValueTensorType>(
3891+
SmallVector<int64_t>{3}, nmsTy.getDtype());
3892+
3893+
auto batchDim3D = rewriter.create<Torch::AtenSelectIntOp>(
3894+
loc, outputTensorSliceType, currRes, cst0, idxCst);
3895+
3896+
auto batchSelect = rewriter.create<Torch::AtenSelectIntOp>(
3897+
loc, emptyTensorTy, batchDim3D, cst0, cst0);
3898+
3899+
auto bCopy = rewriter.create<Torch::AtenCopyOp>(
3900+
loc, batchSelect.getType(), batchSelect, batchValue,
3901+
cstFalse);
3902+
3903+
batchDim3D = rewriter.create<Torch::AtenSelectIntOp>(
3904+
loc, outputTensorSliceType, currRes, cst0, idxCst);
3905+
3906+
auto scatterBatch = rewriter.create<Torch::AtenSelectScatterOp>(
3907+
loc, outputTensorSliceType, batchDim3D, bCopy, cst0, cst0);
3908+
// Write the updated row back into the result tensor
3909+
auto batchResult = rewriter.create<Torch::AtenSelectScatterOp>(
3910+
loc, resultType, currRes, scatterBatch, cst0, idxCst);
3911+
3912+
// Class values
3913+
auto classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
3914+
loc, outputTensorSliceType, batchResult, cst0, idxCst);
3915+
// Class indices
3916+
auto classSelect = rewriter.create<Torch::AtenSelectIntOp>(
3917+
loc, emptyTensorTy, classDim3D, cst0, cst1);
3918+
3919+
auto cCopy = rewriter.create<Torch::AtenCopyOp>(
3920+
loc, classSelect.getType(), classSelect, classValue,
3921+
cstFalse);
3922+
3923+
classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
3924+
loc, outputTensorSliceType, batchResult, cst0, idxCst);
3925+
3926+
auto scatterClass = rewriter.create<Torch::AtenSelectScatterOp>(
3927+
loc, outputTensorSliceType, classDim3D, cCopy, cst0, cst1);
3928+
auto classRes = rewriter.create<Torch::AtenSelectScatterOp>(
3929+
loc, resultType, batchResult, scatterClass, cst0, idxCst);
3930+
3931+
auto resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
3932+
loc, outputTensorSliceType, classRes, cst0, idxCst);
3933+
// Box index (column 2 of the output row)
3934+
auto resSelect = rewriter.create<Torch::AtenSelectIntOp>(
3935+
loc, emptyTensorTy, resDim3D, cst0, cst2);
3936+
3937+
auto nmsResultValue = rewriter.create<Torch::AtenSelectIntOp>(
3938+
loc, emptyTensorTy, result, cst0, iter);
3939+
3940+
auto rCopy = rewriter.create<Torch::AtenCopyOp>(
3941+
loc, resSelect.getType(), resSelect, nmsResultValue,
3942+
cstFalse);
3943+
3944+
resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
3945+
loc, outputTensorSliceType, classRes, cst0, idxCst);
3946+
3947+
auto scatterRes = rewriter.create<Torch::AtenSelectScatterOp>(
3948+
loc, outputTensorSliceType, resDim3D, rCopy, cst0, cst2);
3949+
Value nmsResult = rewriter.create<Torch::AtenSelectScatterOp>(
3950+
loc, resultType, classRes, scatterRes, cst0, idxCst);
3951+
3952+
Value next =
3953+
rewriter.create<Torch::AtenAddIntOp>(loc, idxCst, cst1);
3954+
3955+
rewriter.create<Torch::PrimLoopConditionOp>(
3956+
loc, cstTrue, ValueRange({nmsResult, next}));
3957+
}
3958+
rewriter.create<Torch::PrimLoopConditionOp>(
3959+
loc, cstTrue,
3960+
ValueRange({nmsLoop.getResult(0), nmsLoop.getResult(1)}));
3961+
}
3962+
rewriter.create<Torch::PrimLoopConditionOp>(
3963+
loc, cstTrue,
3964+
ValueRange(
3965+
{nmsClassLoop.getResult(0), nmsClassLoop.getResult(1)}));
37993966
}
3800-
{
3801-
PatternRewriter::InsertionGuard guard(rewriter);
3802-
rewriter.createBlock(&ifSlice.getElseRegion(),
3803-
ifSlice.getElseRegion().begin());
3804-
3805-
Value curResult = rewriter.create<Torch::TensorStaticInfoCastOp>(
3806-
loc, nmsResultTy, result);
3807-
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
3808-
}
3809-
result = ifSlice.getResult(0);
3810-
3811-
// The result generated by torchvision.nms op is of shape [n], while the
3812-
// onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
3813-
// and make it of shape [n, 1] and then concatenate it with a zero
3814-
// tensor of shape [n, 2] to make it of shape [n, 3].
3815-
FailureOr<Value> unsqueezedResult =
3816-
Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
3817-
if (failed(unsqueezedResult))
3818-
return rewriter.notifyMatchFailure(
3819-
binder.op, "failed to unsqueeze result tensor");
3820-
result = unsqueezedResult.value();
3821-
3822-
numOutputBoxes =
3823-
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
3824-
SmallVector<Value> zerosShapeValues{numOutputBoxes};
3825-
zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
3826-
loc, rewriter.getI64IntegerAttr(2)));
3827-
Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
3828-
loc,
3829-
rewriter.getType<Torch::ListType>(
3830-
rewriter.getType<Torch::IntType>()),
3831-
zerosShapeValues);
3832-
std::optional<ArrayRef<int64_t>> resultShape =
3833-
cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
3834-
if (!resultShape.has_value())
3835-
return rewriter.notifyMatchFailure(
3836-
binder.op, "expected result tensor to have shape");
3837-
llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
3838-
auto zerosTy = Torch::ValueTensorType::get(
3839-
resultType.getContext(), zerosShape, resultType.getOptionalDtype());
3840-
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
3841-
Value zeros = rewriter.create<Torch::AtenZerosOp>(
3842-
loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone);
3843-
3844-
Type listElemType =
3845-
cast<Torch::BaseTensorType>(resultType)
3846-
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
3847-
/*optionalDtype=*/nullptr);
3848-
Type listType = Torch::ListType::get(listElemType);
3849-
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
3850-
loc, listType, SmallVector<Value>{zeros, result});
3851-
rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
3852-
tensorList, cst1);
3967+
rewriter.replaceOp(binder.op, nmsBatchLoop.getResult(0));
38533968
return success();
38543969
});
38553970
}

0 commit comments

Comments
 (0)