@@ -3703,30 +3703,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
3703
3703
binder.op , " unimplemented: expected center_point_box "
3704
3704
" attribute value to be 0" );
3705
3705
3706
- // TODO: Support multiple batches and classes
3707
- // Squeeze the boxes and scores tensor.
3708
3706
// In Onnx, the shape of boxes is [BxNx4] while the
3709
3707
// torchvision expects it to be of shape [Nx4]. Similarly, for
3710
3708
// the scores tensor shape in Onnx is [BxCxN] while the
3711
3709
// torchvision expects it to be of shape [N].
3712
3710
Value boxes = operands[0 ], scores = operands[1 ];
3713
- FailureOr<Value> squeezedBoxes =
3714
- Torch::squeezeTensor (rewriter, binder.op , loc, 0 , boxes);
3715
- if (failed (squeezedBoxes))
3716
- return rewriter.notifyMatchFailure (binder.op ,
3717
- " failed to squeeze boxes tensor" );
3718
- FailureOr<Value> squeezedScores =
3719
- Torch::squeezeTensor (rewriter, binder.op , loc, 0 , scores);
3720
- if (failed (squeezedScores))
3721
- return rewriter.notifyMatchFailure (binder.op ,
3722
- " failed to squeeze scores tensor" );
3723
- squeezedScores = Torch::squeezeTensor (rewriter, binder.op , loc, 0 ,
3724
- squeezedScores.value ());
3725
- if (failed (squeezedScores))
3726
- return rewriter.notifyMatchFailure (binder.op ,
3727
- " failed to squeeze scores tensor" );
3728
- boxes = squeezedBoxes.value ();
3729
- scores = squeezedScores.value ();
3730
3711
3731
3712
// TODO: Support score_threshold input
3732
3713
// Filter out the boxes if the score < score_threshold
@@ -3755,6 +3736,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
3755
3736
loc, rewriter.getI64IntegerAttr (0 ));
3756
3737
Value cst1 = rewriter.create <Torch::ConstantIntOp>(
3757
3738
loc, rewriter.getI64IntegerAttr (1 ));
3739
+ Value cst3 = rewriter.create <Torch::ConstantIntOp>(
3740
+ loc, rewriter.getI64IntegerAttr (3 ));
3741
+ Value cst4 = rewriter.create <Torch::ConstantIntOp>(
3742
+ loc, rewriter.getI64IntegerAttr (4 ));
3743
+ Value cstNone = rewriter.create <Torch::ConstantNoneOp>(loc);
3744
+
3758
3745
Value maxOutputBoxesPerClass = cst0;
3759
3746
Value iouThreshold = rewriter.create <Torch::ConstantFloatOp>(
3760
3747
loc, rewriter.getF64FloatAttr (0.0 ));
@@ -3769,87 +3756,195 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
3769
3756
loc, rewriter.getType <Torch::IntType>(), operands[2 ]);
3770
3757
}
3771
3758
3759
+ auto boxTensorType = cast<Torch::ValueTensorType>(boxes.getType ());
3760
+ auto boxSlicedType = rewriter.getType <Torch::ValueTensorType>(
3761
+ boxTensorType.getSizes ().slice (1 ), boxTensorType.getDtype ());
3762
+ auto scoreTensorType = cast<Torch::ValueTensorType>(scores.getType ());
3763
+ auto scoreSlicedType = rewriter.getType <Torch::ValueTensorType>(
3764
+ scoreTensorType.getSizes ().slice (1 ), scoreTensorType.getDtype ());
3765
+
3766
+ auto intListTy = rewriter.getType <Torch::ListType>(
3767
+ rewriter.getType <Torch::IntType>());
3768
+ Value lenShapeList = rewriter.create <Torch::PrimListConstructOp>(
3769
+ loc, intListTy, SmallVector<Value>{cst0, cst3});
3770
+
3771
+ auto emptyResType = rewriter.getType <Torch::ValueTensorType>(
3772
+ ArrayRef<int64_t >{-1 , 3 }, resultType.getDtype ());
3773
+ Value finalResult = rewriter.create <Torch::AtenEmptyMemoryFormatOp>(
3774
+ loc, emptyResType, lenShapeList, /* dtype=*/ cst4, /* layout=*/ cstNone,
3775
+ /* device=*/ cstNone, /* pinMemory=*/ cstNone,
3776
+ /* memoryFormat=*/ cstNone);
3777
+
3772
3778
auto nmsTy = Torch::ValueTensorType::get (
3773
3779
binder.op ->getContext (), SmallVector<int64_t >{-1 },
3774
3780
rewriter.getIntegerType (64 , /* signed=*/ true ));
3775
- Value result = rewriter.create <Torch::TorchvisionNmsOp>(
3776
- loc, nmsTy, boxes, scores, iouThreshold);
3777
3781
3778
- // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
3779
- Value numOutputBoxes =
3780
- rewriter.create <Torch::AtenSizeIntOp>(loc, result, cst0);
3781
- Value boxesCond = rewriter.create <Torch::AtenGtIntOp>(
3782
- loc, numOutputBoxes, maxOutputBoxesPerClass);
3782
+ auto numBatches =
3783
+ rewriter.create <Torch::AtenSizeIntOp>(loc, scores, cst0);
3784
+ auto numClasses =
3785
+ rewriter.create <Torch::AtenSizeIntOp>(loc, scores, cst1);
3786
+ Value cstTrue = rewriter.create <Torch::ConstantBoolOp>(
3787
+ loc, rewriter.getBoolAttr (true ));
3788
+ auto intTy = rewriter.getType <Torch::IntType>();
3789
+
3790
+ auto nmsBatchLoop = rewriter.create <Torch::PrimLoopOp>(
3791
+ loc, TypeRange ({finalResult.getType (), intTy}), numBatches, cstTrue,
3792
+ ValueRange ({finalResult, /* Index to finalResult*/ cst0}));
3783
3793
3784
- auto nmsResultTy = Torch::ValueTensorType::get (
3785
- binder.op ->getContext (),
3786
- SmallVector<int64_t >{resultType.getSizes ()[0 ]},
3787
- rewriter.getIntegerType (64 , /* signed=*/ true ));
3788
- auto ifSlice = rewriter.create <Torch::PrimIfOp>(
3789
- loc, TypeRange ({nmsResultTy}), boxesCond);
3790
3794
{
3795
+
3796
+ // Batch loop body
3791
3797
PatternRewriter::InsertionGuard guard (rewriter);
3792
- rewriter.createBlock (&ifSlice.getThenRegion (),
3793
- ifSlice.getThenRegion ().begin ());
3798
+ Block *batchLoopBody = rewriter.createBlock (
3799
+ &nmsBatchLoop.getRegion (), nmsBatchLoop.getRegion ().begin (),
3800
+ TypeRange ({intTy, finalResult.getType (), intTy}),
3801
+ {loc, loc, loc});
3802
+ auto batchIV = batchLoopBody->getArgument (0 );
3803
+ auto currRes = batchLoopBody->getArgument (1 );
3804
+ auto finalResIdx = batchLoopBody->getArgument (2 );
3794
3805
3795
- Value curResult = rewriter.create <Torch::AtenSliceTensorOp>(
3796
- loc, nmsResultTy, result, /* dim=*/ cst0, /* start=*/ cst0,
3797
- /* end=*/ maxOutputBoxesPerClass, /* step=*/ cst1);
3798
- rewriter.create <Torch::PrimIfYieldOp>(loc, curResult);
3806
+ auto boxValue = rewriter.create <Torch::AtenSelectIntOp>(
3807
+ loc, boxSlicedType, boxes, cst0, batchIV);
3808
+
3809
+ auto nmsClassLoop = rewriter.create <Torch::PrimLoopOp>(
3810
+ loc, TypeRange ({finalResult.getType (), intTy}), numClasses,
3811
+ cstTrue, ValueRange ({currRes, finalResIdx}));
3812
+
3813
+ {
3814
+ // Class loop body
3815
+ PatternRewriter::InsertionGuard guard (rewriter);
3816
+ Block *classLoopBody = rewriter.createBlock (
3817
+ &nmsClassLoop.getRegion (), nmsClassLoop.getRegion ().begin (),
3818
+ TypeRange ({intTy, finalResult.getType (), intTy}),
3819
+ {loc, loc, loc});
3820
+ auto classIV = classLoopBody->getArgument (0 );
3821
+ auto currRes = classLoopBody->getArgument (1 );
3822
+ auto finalResIdx = classLoopBody->getArgument (2 );
3823
+
3824
+ auto scoreSelect = rewriter.create <Torch::AtenSelectIntOp>(
3825
+ loc, scoreSlicedType, scores, cst0, batchIV);
3826
+
3827
+ auto scoreSelectType =
3828
+ dyn_cast<Torch::ValueTensorType>(scoreSelect.getType ());
3829
+ assert (scoreSelectType);
3830
+ auto scoreValueType = rewriter.getType <Torch::ValueTensorType>(
3831
+ scoreSelectType.getSizes ().slice (1 ),
3832
+ scoreSelectType.getDtype ());
3833
+
3834
+ auto scoreValue = rewriter.create <Torch::AtenSelectIntOp>(
3835
+ loc, scoreValueType, scoreSelect, cst0, classIV);
3836
+
3837
+ Value result = rewriter.create <Torch::TorchvisionNmsOp>(
3838
+ loc, nmsTy, boxValue, scoreValue, iouThreshold);
3839
+
3840
+ // Slice the result if numOutputBoxes (N) >
3841
+ // max_output_boxes_per_class
3842
+ Value numOutputBoxes =
3843
+ rewriter.create <Torch::AtenSizeIntOp>(loc, result, cst0);
3844
+ Value boxesCond = rewriter.create <Torch::AtenGtIntOp>(
3845
+ loc, numOutputBoxes, maxOutputBoxesPerClass);
3846
+
3847
+ auto nmsResultTy = Torch::ValueTensorType::get (
3848
+ binder.op ->getContext (), SmallVector<int64_t >{-1 },
3849
+ rewriter.getIntegerType (64 , /* signed=*/ true ));
3850
+ auto ifSlice = rewriter.create <Torch::PrimIfOp>(
3851
+ loc, TypeRange ({nmsResultTy}), boxesCond);
3852
+ {
3853
+ PatternRewriter::InsertionGuard guard (rewriter);
3854
+ rewriter.createBlock (&ifSlice.getThenRegion (),
3855
+ ifSlice.getThenRegion ().begin ());
3856
+
3857
+ Value curResult = rewriter.create <Torch::AtenSliceTensorOp>(
3858
+ loc, nmsResultTy, result, /* dim=*/ cst0, /* start=*/ cst0,
3859
+ /* end=*/ maxOutputBoxesPerClass, /* step=*/ cst1);
3860
+ rewriter.create <Torch::PrimIfYieldOp>(loc, curResult);
3861
+ }
3862
+ {
3863
+ PatternRewriter::InsertionGuard guard (rewriter);
3864
+ rewriter.createBlock (&ifSlice.getElseRegion (),
3865
+ ifSlice.getElseRegion ().begin ());
3866
+
3867
+ Value curResult = rewriter.create <Torch::TensorStaticInfoCastOp>(
3868
+ loc, nmsResultTy, result);
3869
+ rewriter.create <Torch::PrimIfYieldOp>(loc, curResult);
3870
+ }
3871
+ result = ifSlice.getResult (0 );
3872
+
3873
+ // The result generated by torchvision.nms op is of shape [n], while
3874
+ // the onnx expects it to be of shape [n, 3]. Hence, we unsqueeze
3875
+ // the tensor and make it of shape [n, 1] and then concatenate it
3876
+ // with batch and class values to make it shape [n, 3].
3877
+ FailureOr<Value> unsqueezedResult =
3878
+ Torch::unsqueezeTensor (rewriter, binder.op , result, cst1);
3879
+ if (failed (unsqueezedResult))
3880
+ return rewriter.notifyMatchFailure (
3881
+ binder.op , " failed to unsqueeze result tensor" );
3882
+ result = unsqueezedResult.value ();
3883
+
3884
+ auto resultNmsType = cast<Torch::ValueTensorType>(result.getType ());
3885
+ numOutputBoxes =
3886
+ rewriter.create <Torch::AtenSizeIntOp>(loc, result, cst0);
3887
+
3888
+ Value catList = rewriter.create <Torch::PrimListConstructOp>(
3889
+ loc, intListTy, SmallVector<Value>{numOutputBoxes, cst1});
3890
+
3891
+ Value resB = rewriter.create <Torch::AtenEmptyMemoryFormatOp>(
3892
+ loc, resultNmsType, catList, /* dtype=*/ cst4,
3893
+ /* layout=*/ cstNone,
3894
+ /* device=*/ cstNone, /* pinMemory=*/ cstNone,
3895
+ /* memoryFormat=*/ cstNone);
3896
+ auto bVal = rewriter.create <Torch::AtenFillScalarOp>(
3897
+ loc, resultNmsType, resB, batchIV);
3898
+
3899
+ Value resC = rewriter.create <Torch::AtenEmptyMemoryFormatOp>(
3900
+ loc, resultNmsType, catList, /* dtype=*/ cst4,
3901
+ /* layout=*/ cstNone,
3902
+ /* device=*/ cstNone, /* pinMemory=*/ cstNone,
3903
+ /* memoryFormat=*/ cstNone);
3904
+ auto cVal = rewriter.create <Torch::AtenFillScalarOp>(
3905
+ loc, resultNmsType, resC, classIV);
3906
+
3907
+ Type listElemType =
3908
+ cast<Torch::BaseTensorType>(resultType)
3909
+ .getWithSizesAndDtype (/* optionalSizes=*/ std::nullopt,
3910
+ /* optionalDtype=*/ nullptr );
3911
+ Type listType = Torch::ListType::get (listElemType);
3912
+
3913
+ Value classResList = rewriter.create <Torch::PrimListConstructOp>(
3914
+ loc, listType, SmallVector<Value>{cVal, result});
3915
+
3916
+ auto cat1Type = rewriter.getType <Torch::ValueTensorType>(
3917
+ ArrayRef<int64_t >{-1 , 2 }, resultNmsType.getDtype ());
3918
+
3919
+ auto cat1 = rewriter.create <Torch::AtenCatOp>(loc, cat1Type,
3920
+ classResList, cst1);
3921
+
3922
+ auto cat2Type = rewriter.getType <Torch::ValueTensorType>(
3923
+ SmallVector<int64_t >{-1 , 3 }, resultNmsType.getDtype ());
3924
+
3925
+ Value batchClassResList =
3926
+ rewriter.create <Torch::PrimListConstructOp>(
3927
+ loc, listType, SmallVector<Value>{bVal, cat1});
3928
+ auto cat2 = rewriter.create <Torch::AtenCatOp>(
3929
+ loc, cat2Type, batchClassResList, cst1);
3930
+
3931
+ Value appendRes = rewriter.create <Torch::PrimListConstructOp>(
3932
+ loc, listType, SmallVector<Value>{currRes, cat2});
3933
+ auto catResult = rewriter.create <Torch::AtenCatOp>(loc, cat2Type,
3934
+ appendRes, cst0);
3935
+ Value next =
3936
+ rewriter.create <Torch::AtenAddIntOp>(loc, finalResIdx, cst1);
3937
+ rewriter.create <Torch::PrimLoopConditionOp>(
3938
+ loc, cstTrue, ValueRange ({catResult, next}));
3939
+ }
3940
+ rewriter.create <Torch::PrimLoopConditionOp>(
3941
+ loc, cstTrue,
3942
+ ValueRange (
3943
+ {nmsClassLoop.getResult (0 ), nmsClassLoop.getResult (1 )}));
3799
3944
}
3800
- {
3801
- PatternRewriter::InsertionGuard guard (rewriter);
3802
- rewriter.createBlock (&ifSlice.getElseRegion (),
3803
- ifSlice.getElseRegion ().begin ());
3804
-
3805
- Value curResult = rewriter.create <Torch::TensorStaticInfoCastOp>(
3806
- loc, nmsResultTy, result);
3807
- rewriter.create <Torch::PrimIfYieldOp>(loc, curResult);
3808
- }
3809
- result = ifSlice.getResult (0 );
3810
-
3811
- // The result generated by torchvision.nms op is of shape [n], while the
3812
- // onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
3813
- // and make it of shape [n, 1] and then concatenate it with a zero
3814
- // tensor of shape [n, 2] to make it of shape [n, 3].
3815
- FailureOr<Value> unsqueezedResult =
3816
- Torch::unsqueezeTensor (rewriter, binder.op , result, cst1);
3817
- if (failed (unsqueezedResult))
3818
- return rewriter.notifyMatchFailure (
3819
- binder.op , " failed to unsqueeze result tensor" );
3820
- result = unsqueezedResult.value ();
3821
-
3822
- numOutputBoxes =
3823
- rewriter.create <Torch::AtenSizeIntOp>(loc, result, cst0);
3824
- SmallVector<Value> zerosShapeValues{numOutputBoxes};
3825
- zerosShapeValues.push_back (rewriter.create <Torch::ConstantIntOp>(
3826
- loc, rewriter.getI64IntegerAttr (2 )));
3827
- Value zerosShapeList = rewriter.create <Torch::PrimListConstructOp>(
3828
- loc,
3829
- rewriter.getType <Torch::ListType>(
3830
- rewriter.getType <Torch::IntType>()),
3831
- zerosShapeValues);
3832
- std::optional<ArrayRef<int64_t >> resultShape =
3833
- cast<Torch::ValueTensorType>(result.getType ()).getOptionalSizes ();
3834
- if (!resultShape.has_value ())
3835
- return rewriter.notifyMatchFailure (
3836
- binder.op , " expected result tensor to have shape" );
3837
- llvm::SmallVector<int64_t > zerosShape = {resultShape->front (), 2 };
3838
- auto zerosTy = Torch::ValueTensorType::get (
3839
- resultType.getContext (), zerosShape, resultType.getOptionalDtype ());
3840
- Value cstNone = rewriter.create <Torch::ConstantNoneOp>(loc);
3841
- Value zeros = rewriter.create <Torch::AtenZerosOp>(
3842
- loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone);
3843
-
3844
- Type listElemType =
3845
- cast<Torch::BaseTensorType>(resultType)
3846
- .getWithSizesAndDtype (/* optionalSizes=*/ std::nullopt,
3847
- /* optionalDtype=*/ nullptr );
3848
- Type listType = Torch::ListType::get (listElemType);
3849
- Value tensorList = rewriter.create <Torch::PrimListConstructOp>(
3850
- loc, listType, SmallVector<Value>{zeros, result});
3851
- rewriter.replaceOpWithNewOp <Torch::AtenCatOp>(binder.op , resultType,
3852
- tensorList, cst1);
3945
+
3946
+ rewriter.replaceOpWithNewOp <Torch::TensorStaticInfoCastOp>(
3947
+ binder.op , resultType, nmsBatchLoop.getResult (0 ));
3853
3948
return success ();
3854
3949
});
3855
3950
}
0 commit comments