Skip to content

Support batches and classes for nms lowering using loops #3980

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 207 additions & 92 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
Expand Down Expand Up @@ -3703,30 +3704,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, "unimplemented: expected center_point_box "
"attribute value to be 0");

// TODO: Support multiple batches and classes
// Squeeze the boxes and scores tensor.
// In Onnx, the shape of boxes is [BxNx4] while the
// torchvision expects it to be of shape [Nx4]. Similarly, for
// the scores tensor shape in Onnx is [BxCxN] while the
// torchvision expects it to be of shape [N].
Value boxes = operands[0], scores = operands[1];
FailureOr<Value> squeezedBoxes =
Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes);
if (failed(squeezedBoxes))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze boxes tensor");
FailureOr<Value> squeezedScores =
Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores);
if (failed(squeezedScores))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze scores tensor");
squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0,
squeezedScores.value());
if (failed(squeezedScores))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze scores tensor");
boxes = squeezedBoxes.value();
scores = squeezedScores.value();

// TODO: Support score_threshold input
// Filter out the boxes if the score < score_threshold
Expand All @@ -3750,12 +3732,26 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
"unimplemented: score_threshold should be <= min(scores)"));
}

// Get max_output_boxes_per_class and iou_threshold
Value cst0 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value cst1 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value cst2 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(2));
Value cst3 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(3));
Value cst4 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(4));

Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(
loc, rewriter.getBoolAttr(true));
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
loc, rewriter.getBoolAttr(false));

Value maxOutputBoxesPerClass = cst0;

// Get max_output_boxes_per_class and iou_threshold
Value iouThreshold = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(0.0));
if (operands.size() > 3 &&
Expand All @@ -3769,87 +3765,206 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
loc, rewriter.getType<Torch::IntType>(), operands[2]);
}

auto boxesTensorType = cast<Torch::ValueTensorType>(boxes.getType());
auto scoreTensorType = cast<Torch::ValueTensorType>(scores.getType());
auto boxSlicedType = rewriter.getType<Torch::ValueTensorType>(
boxesTensorType.getSizes().slice(1), boxesTensorType.getDtype());
auto scoreSlicedType = rewriter.getType<Torch::ValueTensorType>(
scoreTensorType.getSizes().slice(1), scoreTensorType.getDtype());

auto numBatches =
rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst0);
auto numClasses =
rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst1);
// auto numBoxes =
// rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst2);

std::optional<ArrayRef<int64_t>> resultShape =
cast<Torch::ValueTensorType>(resultType).getOptionalSizes();
if (!resultShape.has_value())
return rewriter.notifyMatchFailure(
binder.op, "Expected result tensor to have shape");

Value numResults = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(resultShape->front()));

auto intTy = rewriter.getType<Torch::IntType>();
auto intListTy = rewriter.getType<Torch::ListType>(intTy);

Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
loc, intListTy, SmallVector<Value>{numResults, cst3});

Value finalResult = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
loc, resultType, resultShapeList, /*dtype=*/cst4,
/*layout=*/cstNone,
/*device=*/cstNone, /*pinMemory=*/cstNone,
/*memoryFormat=*/cstNone);

auto nmsTy = Torch::ValueTensorType::get(
binder.op->getContext(), SmallVector<int64_t>{-1},
rewriter.getIntegerType(64, /*signed=*/true));
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
loc, nmsTy, boxes, scores, iouThreshold);

// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
Value numOutputBoxes =
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
loc, numOutputBoxes, maxOutputBoxesPerClass);
auto emptyTensorTy = rewriter.getType<Torch::ValueTensorType>(
SmallVector<int64_t>{}, nmsTy.getDtype());

auto nmsResultTy = Torch::ValueTensorType::get(
binder.op->getContext(),
SmallVector<int64_t>{resultType.getSizes()[0]},
rewriter.getIntegerType(64, /*signed=*/true));
auto ifSlice = rewriter.create<Torch::PrimIfOp>(
loc, TypeRange({nmsResultTy}), boxesCond);
auto nmsBatchLoop = rewriter.create<Torch::PrimLoopOp>(
loc, TypeRange({resultType, intTy}), numBatches, cstTrue,
ValueRange({finalResult, /*Index to finalResult*/ cst0}));
{
// Batch loop body
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.createBlock(&ifSlice.getThenRegion(),
ifSlice.getThenRegion().begin());
Block *batchLoopBody = rewriter.createBlock(
&nmsBatchLoop.getRegion(), nmsBatchLoop.getRegion().begin(),
TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
auto batchIV = batchLoopBody->getArgument(0);
auto currRes = batchLoopBody->getArgument(1);
auto finalResIdx = batchLoopBody->getArgument(2);

auto boxValue = rewriter.create<Torch::AtenSelectIntOp>(
loc, boxSlicedType, boxes, cst0, batchIV);

Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
/*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
auto batchValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
loc, emptyTensorTy, batchIV);

auto nmsClassLoop = rewriter.create<Torch::PrimLoopOp>(
loc, TypeRange({resultType, intTy}), numClasses, cstTrue,
ValueRange({currRes, finalResIdx}));

{
// Class loop body
PatternRewriter::InsertionGuard guard(rewriter);
Block *classLoopBody = rewriter.createBlock(
&nmsClassLoop.getRegion(), nmsClassLoop.getRegion().begin(),
TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
auto classIV = classLoopBody->getArgument(0);
auto currRes = classLoopBody->getArgument(1);
auto finalResIdx = classLoopBody->getArgument(2);

auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
loc, scoreSlicedType, scores, cst0, batchIV);
auto scoreSelectType =
dyn_cast<Torch::ValueTensorType>(scoreSelect.getType());
auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
scoreSelectType.getSizes().slice(1),
scoreSelectType.getDtype());

auto scoreValue = rewriter.create<Torch::AtenSelectIntOp>(
loc, scoreValueType, scoreSelect, cst0, classIV);

auto classValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
loc, emptyTensorTy, classIV);

Value result = rewriter.create<Torch::TorchvisionNmsOp>(
loc, nmsTy, boxValue, scoreValue, iouThreshold);

Value numOutputBoxes =
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);

numOutputBoxes = rewriter.create<Torch::PrimNumToTensorScalarOp>(
loc, emptyTensorTy, numOutputBoxes);
Value maxBoxesPerClass =
rewriter.create<Torch::PrimNumToTensorScalarOp>(
loc, emptyTensorTy, maxOutputBoxesPerClass);
auto minVal = rewriter.create<Torch::AtenMinimumOp>(
loc, numOutputBoxes.getType(), numOutputBoxes,
maxBoxesPerClass);
numOutputBoxes =
rewriter.create<Torch::AtenItemOp>(loc, intTy, minVal);

// Loop through the nms result
auto nmsLoop = rewriter.create<Torch::PrimLoopOp>(
loc, TypeRange({resultType, intTy}), numOutputBoxes, cstTrue,
ValueRange({currRes, finalResIdx}));
{
PatternRewriter::InsertionGuard guard(rewriter);
Block *loopBody = rewriter.createBlock(
&nmsLoop.getRegion(), nmsLoop.getRegion().begin(),
TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
auto iter = loopBody->getArgument(0);
auto currRes = loopBody->getArgument(1);
auto idxCst = loopBody->getArgument(2);

// Update batch

auto outputTensorSliceType =
rewriter.getType<Torch::ValueTensorType>(
SmallVector<int64_t>{3}, nmsTy.getDtype());

auto batchDim3D = rewriter.create<Torch::AtenSelectIntOp>(
loc, outputTensorSliceType, currRes, cst0, idxCst);

auto batchSelect = rewriter.create<Torch::AtenSelectIntOp>(
loc, emptyTensorTy, batchDim3D, cst0, cst0);

auto bCopy = rewriter.create<Torch::AtenCopyOp>(
loc, batchSelect.getType(), batchSelect, batchValue,
cstFalse);

batchDim3D = rewriter.create<Torch::AtenSelectIntOp>(
loc, outputTensorSliceType, currRes, cst0, idxCst);

auto scatterBatch = rewriter.create<Torch::AtenSelectScatterOp>(
loc, outputTensorSliceType, batchDim3D, bCopy, cst0, cst0);
// Yield result
auto batchResult = rewriter.create<Torch::AtenSelectScatterOp>(
loc, resultType, currRes, scatterBatch, cst0, idxCst);

// Class values
auto classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
loc, outputTensorSliceType, batchResult, cst0, idxCst);
// Class indices
auto classSelect = rewriter.create<Torch::AtenSelectIntOp>(
loc, emptyTensorTy, classDim3D, cst0, cst1);

auto cCopy = rewriter.create<Torch::AtenCopyOp>(
loc, classSelect.getType(), classSelect, classValue,
cstFalse);

classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
loc, outputTensorSliceType, batchResult, cst0, idxCst);

auto scatterClass = rewriter.create<Torch::AtenSelectScatterOp>(
loc, outputTensorSliceType, classDim3D, cCopy, cst0, cst1);
auto classRes = rewriter.create<Torch::AtenSelectScatterOp>(
loc, resultType, batchResult, scatterClass, cst0, idxCst);

auto resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
loc, outputTensorSliceType, classRes, cst0, idxCst);
// Class indices
auto resSelect = rewriter.create<Torch::AtenSelectIntOp>(
loc, emptyTensorTy, resDim3D, cst0, cst2);

auto nmsResultValue = rewriter.create<Torch::AtenSelectIntOp>(
loc, emptyTensorTy, result, cst0, iter);

auto rCopy = rewriter.create<Torch::AtenCopyOp>(
loc, resSelect.getType(), resSelect, nmsResultValue,
cstFalse);

resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
loc, outputTensorSliceType, classRes, cst0, idxCst);

auto scatterRes = rewriter.create<Torch::AtenSelectScatterOp>(
loc, outputTensorSliceType, resDim3D, rCopy, cst0, cst2);
Value nmsResult = rewriter.create<Torch::AtenSelectScatterOp>(
loc, resultType, classRes, scatterRes, cst0, idxCst);

Value next =
rewriter.create<Torch::AtenAddIntOp>(loc, idxCst, cst1);

rewriter.create<Torch::PrimLoopConditionOp>(
loc, cstTrue, ValueRange({nmsResult, next}));
}
rewriter.create<Torch::PrimLoopConditionOp>(
loc, cstTrue,
ValueRange({nmsLoop.getResult(0), nmsLoop.getResult(1)}));
}
rewriter.create<Torch::PrimLoopConditionOp>(
loc, cstTrue,
ValueRange(
{nmsClassLoop.getResult(0), nmsClassLoop.getResult(1)}));
}
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.createBlock(&ifSlice.getElseRegion(),
ifSlice.getElseRegion().begin());

Value curResult = rewriter.create<Torch::TensorStaticInfoCastOp>(
loc, nmsResultTy, result);
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
}
result = ifSlice.getResult(0);

// The result generated by torchvision.nms op is of shape [n], while the
// onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
// and make it of shape [n, 1] and then concatenate it with a zero
// tensor of shape [n, 2] to make it of shape [n, 3].
FailureOr<Value> unsqueezedResult =
Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
if (failed(unsqueezedResult))
return rewriter.notifyMatchFailure(
binder.op, "failed to unsqueeze result tensor");
result = unsqueezedResult.value();

numOutputBoxes =
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
SmallVector<Value> zerosShapeValues{numOutputBoxes};
zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(2)));
Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
zerosShapeValues);
std::optional<ArrayRef<int64_t>> resultShape =
cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
if (!resultShape.has_value())
return rewriter.notifyMatchFailure(
binder.op, "expected result tensor to have shape");
llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
auto zerosTy = Torch::ValueTensorType::get(
resultType.getContext(), zerosShape, resultType.getOptionalDtype());
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
Value zeros = rewriter.create<Torch::AtenZerosOp>(
loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone);

Type listElemType =
cast<Torch::BaseTensorType>(resultType)
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
/*optionalDtype=*/nullptr);
Type listType = Torch::ListType::get(listElemType);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
loc, listType, SmallVector<Value>{zeros, result});
rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
tensorList, cst1);
rewriter.replaceOp(binder.op, nmsBatchLoop.getResult(0));
return success();
});
}