 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"

 using namespace mlir;
@@ -3703,30 +3704,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                binder.op, "unimplemented: expected center_point_box "
                           "attribute value to be 0");

-        // TODO: Support multiple batches and classes
-        // Squeeze the boxes and scores tensor.
         // In Onnx, the shape of boxes is [BxNx4] while the
         // torchvision expects it to be of shape [Nx4]. Similarly, for
         // the scores tensor shape in Onnx is [BxCxN] while the
         // torchvision expects it to be of shape [N].
         Value boxes = operands[0], scores = operands[1];
-        FailureOr<Value> squeezedBoxes =
-            Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes);
-        if (failed(squeezedBoxes))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "failed to squeeze boxes tensor");
-        FailureOr<Value> squeezedScores =
-            Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores);
-        if (failed(squeezedScores))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "failed to squeeze scores tensor");
-        squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0,
-                                              squeezedScores.value());
-        if (failed(squeezedScores))
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "failed to squeeze scores tensor");
-        boxes = squeezedBoxes.value();
-        scores = squeezedScores.value();

         // TODO: Support score_threshold input
         // Filter out the boxes if the score < score_threshold
@@ -3750,12 +3732,26 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                   "unimplemented: score_threshold should be <= min(scores)"));
         }

-        // Get max_output_boxes_per_class and iou_threshold
         Value cst0 = rewriter.create<Torch::ConstantIntOp>(
             loc, rewriter.getI64IntegerAttr(0));
         Value cst1 = rewriter.create<Torch::ConstantIntOp>(
             loc, rewriter.getI64IntegerAttr(1));
+        Value cst2 = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(2));
+        Value cst3 = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(3));
+        Value cst4 = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(4));
+
+        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
+        Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(
+            loc, rewriter.getBoolAttr(true));
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+            loc, rewriter.getBoolAttr(false));
+
         Value maxOutputBoxesPerClass = cst0;
+
+        // Get max_output_boxes_per_class and iou_threshold
         Value iouThreshold = rewriter.create<Torch::ConstantFloatOp>(
             loc, rewriter.getF64FloatAttr(0.0));
         if (operands.size() > 3 &&
@@ -3769,87 +3765,206 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               loc, rewriter.getType<Torch::IntType>(), operands[2]);
         }

+        auto boxesTensorType = cast<Torch::ValueTensorType>(boxes.getType());
+        auto scoreTensorType = cast<Torch::ValueTensorType>(scores.getType());
+        auto boxSlicedType = rewriter.getType<Torch::ValueTensorType>(
+            boxesTensorType.getSizes().slice(1), boxesTensorType.getDtype());
+        auto scoreSlicedType = rewriter.getType<Torch::ValueTensorType>(
+            scoreTensorType.getSizes().slice(1), scoreTensorType.getDtype());
+
+        auto numBatches =
+            rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst0);
+        auto numClasses =
+            rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst1);
+        // auto numBoxes =
+        //     rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst2);
+
+        std::optional<ArrayRef<int64_t>> resultShape =
+            cast<Torch::ValueTensorType>(resultType).getOptionalSizes();
+        if (!resultShape.has_value())
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected result tensor to have shape");
+
+        Value numResults = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(resultShape->front()));
+
+        auto intTy = rewriter.getType<Torch::IntType>();
+        auto intListTy = rewriter.getType<Torch::ListType>(intTy);
+
+        Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
+            loc, intListTy, SmallVector<Value>{numResults, cst3});
+
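+        // Allocate the output buffer: an int64 tensor of shape
+        // [num_results, 3] (cst4 is the torch dtype code for int64). Each row
+        // will be filled with a (batch_index, class_index, box_index) triple,
+        // the layout ONNX NonMaxSuppression expects.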
+        Value finalResult = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
+            loc, resultType, resultShapeList, /*dtype=*/cst4,
+            /*layout=*/cstNone,
+            /*device=*/cstNone, /*pinMemory=*/cstNone,
+            /*memoryFormat=*/cstNone);
+
         auto nmsTy = Torch::ValueTensorType::get(
             binder.op->getContext(), SmallVector<int64_t>{-1},
             rewriter.getIntegerType(64, /*signed=*/true));
-        Value result = rewriter.create<Torch::TorchvisionNmsOp>(
-            loc, nmsTy, boxes, scores, iouThreshold);

-        // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
-        Value numOutputBoxes =
-            rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
-        Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
-            loc, numOutputBoxes, maxOutputBoxesPerClass);
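+        // Rank-0 int64 tensor type, used to round-trip scalar indices and
+        // counts through tensor ops (aten.copy, aten.minimum) below.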
+        auto emptyTensorTy = rewriter.getType<Torch::ValueTensorType>(
+            SmallVector<int64_t>{}, nmsTy.getDtype());

-        auto nmsResultTy = Torch::ValueTensorType::get(
-            binder.op->getContext(),
-            SmallVector<int64_t>{resultType.getSizes()[0]},
-            rewriter.getIntegerType(64, /*signed=*/true));
-        auto ifSlice = rewriter.create<Torch::PrimIfOp>(
-            loc, TypeRange({nmsResultTy}), boxesCond);
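+        // Replace the single torchvision.nms call with three nested loops:
+        // over batches, over classes, and over the boxes kept by NMS for each
+        // (batch, class) pair. Each loop threads the running result tensor
+        // and the next write index through its iteration arguments.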
+        auto nmsBatchLoop = rewriter.create<Torch::PrimLoopOp>(
+            loc, TypeRange({resultType, intTy}), numBatches, cstTrue,
+            ValueRange({finalResult, /*Index to finalResult*/ cst0}));
         {
+          // Batch loop body
           PatternRewriter::InsertionGuard guard(rewriter);
-          rewriter.createBlock(&ifSlice.getThenRegion(),
-                               ifSlice.getThenRegion().begin());
+          Block *batchLoopBody = rewriter.createBlock(
+              &nmsBatchLoop.getRegion(), nmsBatchLoop.getRegion().begin(),
+              TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
+          auto batchIV = batchLoopBody->getArgument(0);
+          auto currRes = batchLoopBody->getArgument(1);
+          auto finalResIdx = batchLoopBody->getArgument(2);
+
+          auto boxValue = rewriter.create<Torch::AtenSelectIntOp>(
+              loc, boxSlicedType, boxes, cst0, batchIV);

-          Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
-              loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
-              /*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
-          rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
+          auto batchValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+              loc, emptyTensorTy, batchIV);
+
+          auto nmsClassLoop = rewriter.create<Torch::PrimLoopOp>(
+              loc, TypeRange({resultType, intTy}), numClasses, cstTrue,
+              ValueRange({currRes, finalResIdx}));
+
+          {
+            // Class loop body
+            PatternRewriter::InsertionGuard guard(rewriter);
+            Block *classLoopBody = rewriter.createBlock(
+                &nmsClassLoop.getRegion(), nmsClassLoop.getRegion().begin(),
+                TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
+            auto classIV = classLoopBody->getArgument(0);
+            auto currRes = classLoopBody->getArgument(1);
+            auto finalResIdx = classLoopBody->getArgument(2);
+
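+            // scores has shape [B, C, N]; select the current batch and then
+            // the current class to get the [N] score vector that
+            // torchvision.nms expects.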
+            auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
+                loc, scoreSlicedType, scores, cst0, batchIV);
+            auto scoreSelectType =
+                dyn_cast<Torch::ValueTensorType>(scoreSelect.getType());
+            auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
+                scoreSelectType.getSizes().slice(1),
+                scoreSelectType.getDtype());
+
+            auto scoreValue = rewriter.create<Torch::AtenSelectIntOp>(
+                loc, scoreValueType, scoreSelect, cst0, classIV);
+
+            auto classValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+                loc, emptyTensorTy, classIV);
+
+            Value result = rewriter.create<Torch::TorchvisionNmsOp>(
+                loc, nmsTy, boxValue, scoreValue, iouThreshold);
+
+            Value numOutputBoxes =
+                rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
+
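+            // Keep at most max_output_boxes_per_class entries: lift both
+            // counts into 0-d tensors, take aten.minimum, and read the
+            // scalar back with aten.item.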
+            numOutputBoxes = rewriter.create<Torch::PrimNumToTensorScalarOp>(
+                loc, emptyTensorTy, numOutputBoxes);
+            Value maxBoxesPerClass =
+                rewriter.create<Torch::PrimNumToTensorScalarOp>(
+                    loc, emptyTensorTy, maxOutputBoxesPerClass);
+            auto minVal = rewriter.create<Torch::AtenMinimumOp>(
+                loc, numOutputBoxes.getType(), numOutputBoxes,
+                maxBoxesPerClass);
+            numOutputBoxes =
+                rewriter.create<Torch::AtenItemOp>(loc, intTy, minVal);
+
+            // Loop through the nms result
+            auto nmsLoop = rewriter.create<Torch::PrimLoopOp>(
+                loc, TypeRange({resultType, intTy}), numOutputBoxes, cstTrue,
+                ValueRange({currRes, finalResIdx}));
+            {
+              PatternRewriter::InsertionGuard guard(rewriter);
+              Block *loopBody = rewriter.createBlock(
+                  &nmsLoop.getRegion(), nmsLoop.getRegion().begin(),
+                  TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
+              auto iter = loopBody->getArgument(0);
+              auto currRes = loopBody->getArgument(1);
+              auto idxCst = loopBody->getArgument(2);
+
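+              // Fill one output row at index idxCst: column 0 = batch index,
+              // column 1 = class index, column 2 = the box index kept by NMS.
+              // Each column is written via select + copy + select_scatter on
+              // the running result tensor.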
+              // Update batch index (column 0)
+
+              auto outputTensorSliceType =
+                  rewriter.getType<Torch::ValueTensorType>(
+                      SmallVector<int64_t>{3}, nmsTy.getDtype());
+
+              auto batchDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, currRes, cst0, idxCst);
+
+              auto batchSelect = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, emptyTensorTy, batchDim3D, cst0, cst0);
+
+              auto bCopy = rewriter.create<Torch::AtenCopyOp>(
+                  loc, batchSelect.getType(), batchSelect, batchValue,
+                  cstFalse);
+
+              batchDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, currRes, cst0, idxCst);
+
+              auto scatterBatch = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, outputTensorSliceType, batchDim3D, bCopy, cst0, cst0);
+              // Write the updated row back into the running result
+              auto batchResult = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, resultType, currRes, scatterBatch, cst0, idxCst);
+
+              // Select the current output row
+              auto classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, batchResult, cst0, idxCst);
+              // Class index (column 1)
+              auto classSelect = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, emptyTensorTy, classDim3D, cst0, cst1);
+
+              auto cCopy = rewriter.create<Torch::AtenCopyOp>(
+                  loc, classSelect.getType(), classSelect, classValue,
+                  cstFalse);
+
+              classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, batchResult, cst0, idxCst);
+
+              auto scatterClass = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, outputTensorSliceType, classDim3D, cCopy, cst0, cst1);
+              auto classRes = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, resultType, batchResult, scatterClass, cst0, idxCst);
+
+              auto resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, classRes, cst0, idxCst);
+              // Box index (column 2)
+              auto resSelect = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, emptyTensorTy, resDim3D, cst0, cst2);
+
+              auto nmsResultValue = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, emptyTensorTy, result, cst0, iter);
+
+              auto rCopy = rewriter.create<Torch::AtenCopyOp>(
+                  loc, resSelect.getType(), resSelect, nmsResultValue,
+                  cstFalse);
+
+              resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
+                  loc, outputTensorSliceType, classRes, cst0, idxCst);
+
+              auto scatterRes = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, outputTensorSliceType, resDim3D, rCopy, cst0, cst2);
+              Value nmsResult = rewriter.create<Torch::AtenSelectScatterOp>(
+                  loc, resultType, classRes, scatterRes, cst0, idxCst);
+
+              Value next =
+                  rewriter.create<Torch::AtenAddIntOp>(loc, idxCst, cst1);
+
+              rewriter.create<Torch::PrimLoopConditionOp>(
+                  loc, cstTrue, ValueRange({nmsResult, next}));
+            }
+            rewriter.create<Torch::PrimLoopConditionOp>(
+                loc, cstTrue,
+                ValueRange({nmsLoop.getResult(0), nmsLoop.getResult(1)}));
+          }
+          rewriter.create<Torch::PrimLoopConditionOp>(
+              loc, cstTrue,
+              ValueRange(
+                  {nmsClassLoop.getResult(0), nmsClassLoop.getResult(1)}));
         }
-        {
-          PatternRewriter::InsertionGuard guard(rewriter);
-          rewriter.createBlock(&ifSlice.getElseRegion(),
-                               ifSlice.getElseRegion().begin());
-
-          Value curResult = rewriter.create<Torch::TensorStaticInfoCastOp>(
-              loc, nmsResultTy, result);
-          rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
-        }
-        result = ifSlice.getResult(0);
-
-        // The result generated by torchvision.nms op is of shape [n], while the
-        // onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
-        // and make it of shape [n, 1] and then concatenate it with a zero
-        // tensor of shape [n, 2] to make it of shape [n, 3].
-        FailureOr<Value> unsqueezedResult =
-            Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
-        if (failed(unsqueezedResult))
-          return rewriter.notifyMatchFailure(
-              binder.op, "failed to unsqueeze result tensor");
-        result = unsqueezedResult.value();
-
-        numOutputBoxes =
-            rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
-        SmallVector<Value> zerosShapeValues{numOutputBoxes};
-        zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(2)));
-        Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
-            loc,
-            rewriter.getType<Torch::ListType>(
-                rewriter.getType<Torch::IntType>()),
-            zerosShapeValues);
-        std::optional<ArrayRef<int64_t>> resultShape =
-            cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
-        if (!resultShape.has_value())
-          return rewriter.notifyMatchFailure(
-              binder.op, "expected result tensor to have shape");
-        llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
-        auto zerosTy = Torch::ValueTensorType::get(
-            resultType.getContext(), zerosShape, resultType.getOptionalDtype());
-        Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
-        Value zeros = rewriter.create<Torch::AtenZerosOp>(
-            loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone);
-
-        Type listElemType =
-            cast<Torch::BaseTensorType>(resultType)
-                .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
-                                      /*optionalDtype=*/nullptr);
-        Type listType = Torch::ListType::get(listElemType);
-        Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
-            loc, listType, SmallVector<Value>{zeros, result});
-        rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
-                                                      tensorList, cst1);
+        rewriter.replaceOp(binder.op, nmsBatchLoop.getResult(0));
         return success();
       });
 }