Skip to content

Commit 75362a2

Browse files
committed
Support batch and classes for nms lowering
1 parent 481da8d commit 75362a2

File tree

1 file changed

+186
-91
lines changed

1 file changed

+186
-91
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 186 additions & 91 deletions
Original file line number | Diff line number | Diff line change
@@ -3703,30 +3703,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
37033703
binder.op, "unimplemented: expected center_point_box "
37043704
"attribute value to be 0");
37053705

3706-
// TODO: Support multiple batches and classes
3707-
// Squeeze the boxes and scores tensor.
37083706
// In ONNX, the boxes tensor has shape [BxNx4], while
37093707
// torchvision expects shape [Nx4]. Similarly, the
37103708
// ONNX scores tensor has shape [BxCxN], while
37113709
// torchvision expects shape [N].
37123710
Value boxes = operands[0], scores = operands[1];
3713-
FailureOr<Value> squeezedBoxes =
3714-
Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes);
3715-
if (failed(squeezedBoxes))
3716-
return rewriter.notifyMatchFailure(binder.op,
3717-
"failed to squeeze boxes tensor");
3718-
FailureOr<Value> squeezedScores =
3719-
Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores);
3720-
if (failed(squeezedScores))
3721-
return rewriter.notifyMatchFailure(binder.op,
3722-
"failed to squeeze scores tensor");
3723-
squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0,
3724-
squeezedScores.value());
3725-
if (failed(squeezedScores))
3726-
return rewriter.notifyMatchFailure(binder.op,
3727-
"failed to squeeze scores tensor");
3728-
boxes = squeezedBoxes.value();
3729-
scores = squeezedScores.value();
37303711

37313712
// TODO: Support score_threshold input
37323713
// Filter out the boxes if the score < score_threshold
@@ -3755,6 +3736,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
37553736
loc, rewriter.getI64IntegerAttr(0));
37563737
Value cst1 = rewriter.create<Torch::ConstantIntOp>(
37573738
loc, rewriter.getI64IntegerAttr(1));
3739+
Value cst3 = rewriter.create<Torch::ConstantIntOp>(
3740+
loc, rewriter.getI64IntegerAttr(3));
3741+
Value cst4 = rewriter.create<Torch::ConstantIntOp>(
3742+
loc, rewriter.getI64IntegerAttr(4));
3743+
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
3744+
37583745
Value maxOutputBoxesPerClass = cst0;
37593746
Value iouThreshold = rewriter.create<Torch::ConstantFloatOp>(
37603747
loc, rewriter.getF64FloatAttr(0.0));
@@ -3769,87 +3756,195 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
37693756
loc, rewriter.getType<Torch::IntType>(), operands[2]);
37703757
}
37713758

3759+
auto boxTensorType = cast<Torch::ValueTensorType>(boxes.getType());
3760+
auto boxSlicedType = rewriter.getType<Torch::ValueTensorType>(
3761+
boxTensorType.getSizes().slice(1), boxTensorType.getDtype());
3762+
auto scoreTensorType = cast<Torch::ValueTensorType>(scores.getType());
3763+
auto scoreSlicedType = rewriter.getType<Torch::ValueTensorType>(
3764+
scoreTensorType.getSizes().slice(1), scoreTensorType.getDtype());
3765+
3766+
auto intListTy = rewriter.getType<Torch::ListType>(
3767+
rewriter.getType<Torch::IntType>());
3768+
Value lenShapeList = rewriter.create<Torch::PrimListConstructOp>(
3769+
loc, intListTy, SmallVector<Value>{cst0, cst3});
3770+
3771+
auto emptyResType = rewriter.getType<Torch::ValueTensorType>(
3772+
ArrayRef<int64_t>{-1, 3}, resultType.getDtype());
3773+
Value finalResult = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
3774+
loc, emptyResType, lenShapeList, /*dtype=*/cst4, /*layout=*/cstNone,
3775+
/*device=*/cstNone, /*pinMemory=*/cstNone,
3776+
/*memoryFormat=*/cstNone);
3777+
37723778
auto nmsTy = Torch::ValueTensorType::get(
37733779
binder.op->getContext(), SmallVector<int64_t>{-1},
37743780
rewriter.getIntegerType(64, /*signed=*/true));
3775-
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
3776-
loc, nmsTy, boxes, scores, iouThreshold);
37773781

3778-
// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
3779-
Value numOutputBoxes =
3780-
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
3781-
Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
3782-
loc, numOutputBoxes, maxOutputBoxesPerClass);
3782+
auto numBatches =
3783+
rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst0);
3784+
auto numClasses =
3785+
rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst1);
3786+
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(
3787+
loc, rewriter.getBoolAttr(true));
3788+
auto intTy = rewriter.getType<Torch::IntType>();
3789+
3790+
auto nmsBatchLoop = rewriter.create<Torch::PrimLoopOp>(
3791+
loc, TypeRange({finalResult.getType(), intTy}), numBatches, cstTrue,
3792+
ValueRange({finalResult, /*Index to finalResult*/ cst0}));
37833793

3784-
auto nmsResultTy = Torch::ValueTensorType::get(
3785-
binder.op->getContext(),
3786-
SmallVector<int64_t>{resultType.getSizes()[0]},
3787-
rewriter.getIntegerType(64, /*signed=*/true));
3788-
auto ifSlice = rewriter.create<Torch::PrimIfOp>(
3789-
loc, TypeRange({nmsResultTy}), boxesCond);
37903794
{
3795+
3796+
// Batch loop body
37913797
PatternRewriter::InsertionGuard guard(rewriter);
3792-
rewriter.createBlock(&ifSlice.getThenRegion(),
3793-
ifSlice.getThenRegion().begin());
3798+
Block *batchLoopBody = rewriter.createBlock(
3799+
&nmsBatchLoop.getRegion(), nmsBatchLoop.getRegion().begin(),
3800+
TypeRange({intTy, finalResult.getType(), intTy}),
3801+
{loc, loc, loc});
3802+
auto batchIV = batchLoopBody->getArgument(0);
3803+
auto currRes = batchLoopBody->getArgument(1);
3804+
auto finalResIdx = batchLoopBody->getArgument(2);
37943805

3795-
Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
3796-
loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
3797-
/*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
3798-
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
3806+
auto boxValue = rewriter.create<Torch::AtenSelectIntOp>(
3807+
loc, boxSlicedType, boxes, cst0, batchIV);
3808+
3809+
auto nmsClassLoop = rewriter.create<Torch::PrimLoopOp>(
3810+
loc, TypeRange({finalResult.getType(), intTy}), numClasses,
3811+
cstTrue, ValueRange({currRes, finalResIdx}));
3812+
3813+
{
3814+
// Class loop body
3815+
PatternRewriter::InsertionGuard guard(rewriter);
3816+
Block *classLoopBody = rewriter.createBlock(
3817+
&nmsClassLoop.getRegion(), nmsClassLoop.getRegion().begin(),
3818+
TypeRange({intTy, finalResult.getType(), intTy}),
3819+
{loc, loc, loc});
3820+
auto classIV = classLoopBody->getArgument(0);
3821+
auto currRes = classLoopBody->getArgument(1);
3822+
auto finalResIdx = classLoopBody->getArgument(2);
3823+
3824+
auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
3825+
loc, scoreSlicedType, scores, cst0, batchIV);
3826+
3827+
auto scoreSelectType =
3828+
dyn_cast<Torch::ValueTensorType>(scoreSelect.getType());
3829+
assert(scoreSelectType);
3830+
auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
3831+
scoreSelectType.getSizes().slice(1),
3832+
scoreSelectType.getDtype());
3833+
3834+
auto scoreValue = rewriter.create<Torch::AtenSelectIntOp>(
3835+
loc, scoreValueType, scoreSelect, cst0, classIV);
3836+
3837+
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
3838+
loc, nmsTy, boxValue, scoreValue, iouThreshold);
3839+
3840+
// Slice the result if numOutputBoxes (N) >
3841+
// max_output_boxes_per_class
3842+
Value numOutputBoxes =
3843+
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
3844+
Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
3845+
loc, numOutputBoxes, maxOutputBoxesPerClass);
3846+
3847+
auto nmsResultTy = Torch::ValueTensorType::get(
3848+
binder.op->getContext(), SmallVector<int64_t>{-1},
3849+
rewriter.getIntegerType(64, /*signed=*/true));
3850+
auto ifSlice = rewriter.create<Torch::PrimIfOp>(
3851+
loc, TypeRange({nmsResultTy}), boxesCond);
3852+
{
3853+
PatternRewriter::InsertionGuard guard(rewriter);
3854+
rewriter.createBlock(&ifSlice.getThenRegion(),
3855+
ifSlice.getThenRegion().begin());
3856+
3857+
Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
3858+
loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
3859+
/*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
3860+
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
3861+
}
3862+
{
3863+
PatternRewriter::InsertionGuard guard(rewriter);
3864+
rewriter.createBlock(&ifSlice.getElseRegion(),
3865+
ifSlice.getElseRegion().begin());
3866+
3867+
Value curResult = rewriter.create<Torch::TensorStaticInfoCastOp>(
3868+
loc, nmsResultTy, result);
3869+
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
3870+
}
3871+
result = ifSlice.getResult(0);
3872+
3873+
// The result generated by torchvision.nms op is of shape [n], while
3874+
// the onnx expects it to be of shape [n, 3]. Hence, we unsqueeze
3875+
// the tensor and make it of shape [n, 1] and then concatenate it
3876+
// with batch and class values to make it shape [n, 3].
3877+
FailureOr<Value> unsqueezedResult =
3878+
Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
3879+
if (failed(unsqueezedResult))
3880+
return rewriter.notifyMatchFailure(
3881+
binder.op, "failed to unsqueeze result tensor");
3882+
result = unsqueezedResult.value();
3883+
3884+
auto resultNmsType = cast<Torch::ValueTensorType>(result.getType());
3885+
numOutputBoxes =
3886+
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
3887+
3888+
Value catList = rewriter.create<Torch::PrimListConstructOp>(
3889+
loc, intListTy, SmallVector<Value>{numOutputBoxes, cst1});
3890+
3891+
Value resB = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
3892+
loc, resultNmsType, catList, /*dtype=*/cst4,
3893+
/*layout=*/cstNone,
3894+
/*device=*/cstNone, /*pinMemory=*/cstNone,
3895+
/*memoryFormat=*/cstNone);
3896+
auto bVal = rewriter.create<Torch::AtenFillScalarOp>(
3897+
loc, resultNmsType, resB, batchIV);
3898+
3899+
Value resC = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
3900+
loc, resultNmsType, catList, /*dtype=*/cst4,
3901+
/*layout=*/cstNone,
3902+
/*device=*/cstNone, /*pinMemory=*/cstNone,
3903+
/*memoryFormat=*/cstNone);
3904+
auto cVal = rewriter.create<Torch::AtenFillScalarOp>(
3905+
loc, resultNmsType, resC, classIV);
3906+
3907+
Type listElemType =
3908+
cast<Torch::BaseTensorType>(resultType)
3909+
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
3910+
/*optionalDtype=*/nullptr);
3911+
Type listType = Torch::ListType::get(listElemType);
3912+
3913+
Value classResList = rewriter.create<Torch::PrimListConstructOp>(
3914+
loc, listType, SmallVector<Value>{cVal, result});
3915+
3916+
auto cat1Type = rewriter.getType<Torch::ValueTensorType>(
3917+
ArrayRef<int64_t>{-1, 2}, resultNmsType.getDtype());
3918+
3919+
auto cat1 = rewriter.create<Torch::AtenCatOp>(loc, cat1Type,
3920+
classResList, cst1);
3921+
3922+
auto cat2Type = rewriter.getType<Torch::ValueTensorType>(
3923+
SmallVector<int64_t>{-1, 3}, resultNmsType.getDtype());
3924+
3925+
Value batchClassResList =
3926+
rewriter.create<Torch::PrimListConstructOp>(
3927+
loc, listType, SmallVector<Value>{bVal, cat1});
3928+
auto cat2 = rewriter.create<Torch::AtenCatOp>(
3929+
loc, cat2Type, batchClassResList, cst1);
3930+
3931+
Value appendRes = rewriter.create<Torch::PrimListConstructOp>(
3932+
loc, listType, SmallVector<Value>{currRes, cat2});
3933+
auto catResult = rewriter.create<Torch::AtenCatOp>(loc, cat2Type,
3934+
appendRes, cst0);
3935+
Value next =
3936+
rewriter.create<Torch::AtenAddIntOp>(loc, finalResIdx, cst1);
3937+
rewriter.create<Torch::PrimLoopConditionOp>(
3938+
loc, cstTrue, ValueRange({catResult, next}));
3939+
}
3940+
rewriter.create<Torch::PrimLoopConditionOp>(
3941+
loc, cstTrue,
3942+
ValueRange(
3943+
{nmsClassLoop.getResult(0), nmsClassLoop.getResult(1)}));
37993944
}
3800-
{
3801-
PatternRewriter::InsertionGuard guard(rewriter);
3802-
rewriter.createBlock(&ifSlice.getElseRegion(),
3803-
ifSlice.getElseRegion().begin());
3804-
3805-
Value curResult = rewriter.create<Torch::TensorStaticInfoCastOp>(
3806-
loc, nmsResultTy, result);
3807-
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
3808-
}
3809-
result = ifSlice.getResult(0);
3810-
3811-
// The result generated by torchvision.nms op is of shape [n], while the
3812-
// onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
3813-
// and make it of shape [n, 1] and then concatenate it with a zero
3814-
// tensor of shape [n, 2] to make it of shape [n, 3].
3815-
FailureOr<Value> unsqueezedResult =
3816-
Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
3817-
if (failed(unsqueezedResult))
3818-
return rewriter.notifyMatchFailure(
3819-
binder.op, "failed to unsqueeze result tensor");
3820-
result = unsqueezedResult.value();
3821-
3822-
numOutputBoxes =
3823-
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
3824-
SmallVector<Value> zerosShapeValues{numOutputBoxes};
3825-
zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
3826-
loc, rewriter.getI64IntegerAttr(2)));
3827-
Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
3828-
loc,
3829-
rewriter.getType<Torch::ListType>(
3830-
rewriter.getType<Torch::IntType>()),
3831-
zerosShapeValues);
3832-
std::optional<ArrayRef<int64_t>> resultShape =
3833-
cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
3834-
if (!resultShape.has_value())
3835-
return rewriter.notifyMatchFailure(
3836-
binder.op, "expected result tensor to have shape");
3837-
llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
3838-
auto zerosTy = Torch::ValueTensorType::get(
3839-
resultType.getContext(), zerosShape, resultType.getOptionalDtype());
3840-
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
3841-
Value zeros = rewriter.create<Torch::AtenZerosOp>(
3842-
loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone);
3843-
3844-
Type listElemType =
3845-
cast<Torch::BaseTensorType>(resultType)
3846-
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
3847-
/*optionalDtype=*/nullptr);
3848-
Type listType = Torch::ListType::get(listElemType);
3849-
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
3850-
loc, listType, SmallVector<Value>{zeros, result});
3851-
rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
3852-
tensorList, cst1);
3945+
3946+
rewriter.replaceOpWithNewOp<Torch::TensorStaticInfoCastOp>(
3947+
binder.op, resultType, nmsBatchLoop.getResult(0));
38533948
return success();
38543949
});
38553950
}

0 commit comments

Comments
 (0)