@@ -1616,6 +1616,223 @@ class ConvertAtenAdaptivePoolOp : public OpConversionPattern<OpTy> {
1616
1616
};
1617
1617
} // namespace
1618
1618
1619
+ namespace {
1620
+ template <typename OpTy, typename PoolingOpTy, int Dim>
1621
+ class ConvertRoiAlignOp : public OpConversionPattern <OpTy> {
1622
+ public:
1623
+ using OpConversionPattern<OpTy>::OpConversionPattern;
1624
+ LogicalResult
1625
+ matchAndRewrite (OpTy op, typename OpTy::Adaptor adaptor,
1626
+ ConversionPatternRewriter &rewriter) const override {
1627
+ if (failed (verifyLinalgCompatibleTypes (op, rewriter)))
1628
+ return failure ();
1629
+
1630
+ Location loc = op->getLoc ();
1631
+ const TypeConverter *typeConverter = this ->getTypeConverter ();
1632
+ Value result = op.getResult ();
1633
+
1634
+ uint64_t pooledHeight =
1635
+ cast<ConstantIntOp>(op.getPooledHeight ().getDefiningOp ()).getValue ();
1636
+ uint64_t pooledWidth =
1637
+ cast<ConstantIntOp>(op.getPooledWidth ().getDefiningOp ()).getValue ();
1638
+ uint64_t samplingRatio =
1639
+ cast<ConstantIntOp>(op.getSamplingRatio ().getDefiningOp ()).getValue ();
1640
+ Value pooledH = op.getPooledHeight ();
1641
+ Value pooledW = op.getPooledWidth ();
1642
+ Value spatialScaleVal = op.getSpatialScale ();
1643
+ llvm::APFloat spatialScale =
1644
+ cast<ConstantFloatOp>(op.getSpatialScale ().getDefiningOp ()).getValue ();
1645
+ Value rois = op.getRois ();
1646
+ Value input = op.getInput ();
1647
+ // RankedTensorType inputType = input.getType();
1648
+ Value offset =
1649
+ rewriter.create <arith::ConstantOp>(loc, b.getF32FloatAttr (0.0 ));
1650
+ Type resultType = cast<RankedTensorType>(result.getType ());
1651
+ Type resultElementType = resultType.getElementType ();
1652
+ if (!op.getAligned ()) {
1653
+ offset = rewriter.create <arith::ConstantOp>(loc, b.getF32FloatAttr (0.5 ));
1654
+ }
1655
+
1656
+ Value lb = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
1657
+ Value ub0 = rewriter.create <tensor::DimOp>(loc, rois, 0 );
1658
+ Value ub1 = rewriter.create <tensor::DimOp>(loc, input, 1 );
1659
+ Value step = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
1660
+ SmallVector<Value> finalOutputShape = {ub0, ub1, pooledH, pooledW};
1661
+ Value finalOutputTensor = rewriter.create <tensor::EmptyOp>(
1662
+ loc, getAsOpFoldResult (finalOutputShape), resultElementType);
1663
+ auto forLoop = rewriter.create <scf::ForOp>(
1664
+ loc, lb, ub0, step, ValueRange{},
1665
+ [&](OpBuilder &b1, Location loc, Value iv0, ValueRange args) {
1666
+ auto forLoop = b1.create <scf::ForOp>(
1667
+ loc, lb, ub1, step, ValueRange{},
1668
+ [&](OpBuilder &b, Location loc, Value iv1, ValueRange args) {
1669
+ // Step 1: Extract bounds for region of interest (roi)
1670
+ OpFoldResult zeroAttr = b.getI64IntegerAttr (0 );
1671
+ OpFoldResult oneAttr = b.getI64IntegerAttr (1 );
1672
+ OpFoldResult twoAttr = b.getI64IntegerAttr (2 );
1673
+ OpFoldResult threeAttr = b.getI64IntegerAttr (3 );
1674
+ OpFoldResult fourAttr = b.getI64IntegerAttr (4 );
1675
+ OpFoldResult fiveAttr = b.getI64IntegerAttr (5 );
1676
+ // SmallVector<Value> offsetVals{iv0, zeroAttr};
1677
+ // SmallVector<OpFoldResult> sizeVals{oneAttr, fiveAttr};
1678
+ SmallVector<OpFoldResult> strideVals{oneAttr, oneAttr, oneAttr,
1679
+ oneAttr};
1680
+ // Value extractRoiBounds = b.create<tensor::ExtractSliceOp>(
1681
+ // loc, rois, offsetVals, sizeVals, strideVals);
1682
+ Value lowY = b.create <tensor::ExtractOp>(
1683
+ loc, rois, ValueRange{iv0, oneAttr});
1684
+ Value lowX = b.create <tensor::ExtractOp>(
1685
+ loc, rois, ValueRange{iv0, twoAttr});
1686
+ Value highY = b.create <tensor::ExtractOp>(
1687
+ loc, rois, ValueRange{iv0, threeAttr});
1688
+ Value highX = b.create <tensor::ExtractOp>(
1689
+ loc, rois, ValueRange{iv0, fourAttr});
1690
+
1691
+ lowY = b.create <arith::MulFOp>(loc, lowY, spatialScaleVal);
1692
+ lowX = b.create <arith::MulFOp>(loc, lowX, spatialScaleVal);
1693
+ highY = b.create <arith::MulFOp>(loc, highY, spatialScaleVal);
1694
+ highX = b.create <arith::MulFOp>(loc, highX, spatialScaleVal);
1695
+
1696
+ lowY = b.create <arith::SubFOp>(loc, lowY, offset);
1697
+ lowX = b.create <arith::SubFOp>(loc, lowX, offset);
1698
+ highY = b.create <arith::SubFOp>(loc, highY, offset);
1699
+ highX = b.create <arith::SubFOp>(loc, highX, offset);
1700
+
1701
+ // Step 2: Extract region of interest using bounds
1702
+ Value lowY_int = b.create <math::FloorOp>(loc, lowY);
1703
+ Value lowX_int = b.create <math::FloorOp>(loc, lowX);
1704
+ Value highY_int = b.create <math::CeilOp>(loc, highY);
1705
+ Value highX_int = b.create <math::CeilOp>(loc, highX);
1706
+ lowY_int =
1707
+ b.create <arith::FPToSIOp>(loc, b.getI64Type (), lowY_int);
1708
+ lowX_int =
1709
+ b.create <arith::FPToSIOp>(loc, b.getI64Type (), lowX_int);
1710
+ highY_int =
1711
+ b.create <arith::FPToSIOp>(loc, b.getI64Type (), highY_int);
1712
+ highX_int =
1713
+ b.create <arith::FPToSIOp>(loc, b.getI64Type (), highX_int);
1714
+
1715
+ Value roiHeight =
1716
+ b.create <arith::SubIOp>(loc, highY_int, lowY_int);
1717
+ Value roiWidth =
1718
+ b.create <arith::SubIOp>(loc, highX_int, lowX_int);
1719
+
1720
+ SmallVector<Value> roiOffsetVals{zeroAttr, iv1, lowY_int,
1721
+ lowX_int};
1722
+ SmallVector<Value> roiSizeVals{oneAttr, oneAttr, roiHeight,
1723
+ roiWidth};
1724
+
1725
+ Value extractRoi = b.create <tensor::ExtractSliceOp>(
1726
+ loc, input, roiOffsetVals, roiSizeVals, strideVals);
1727
+
1728
+ // Step 3: Perform bilinear interpolation over roi
1729
+ Value roiBinH = b.create <arith::SubOp>(loc, highY, lowY);
1730
+ Value roiBinW = b.create <arith::SubOp>(loc, highX, lowX);
1731
+ Value scaleH = b.create <arith::DivOp>(loc, roiBinH, pooledH);
1732
+ Value scaleW = b.create <arith::DivOp>(loc, roiBinW, pooledW);
1733
+ scaleH = b.create <arith::CeilOp>(loc, scaleH);
1734
+ scaleW = b.create <arith::CeilOp>(loc, scaleW);
1735
+ scaleH = b.create <arith::FPToSIOp>(loc, b.getI64Type (), scaleH);
1736
+ scaleW = b.create <arith::FPToSIOp>(loc, b.getI64Type (), scaleW);
1737
+
1738
+ Value roiSampleHeight =
1739
+ b.create <arith::MulIOp>(loc, pooledH, scaleH);
1740
+ Value roiSampleWidth =
1741
+ b.create <arith::MulIOp>(loc, pooledW, scaleW);
1742
+
1743
+ SmallVector<Value> outputSizeIntValues = {roiSampleHeight,
1744
+ roiSampleWidth};
1745
+ SmallVector<Value> dims =
1746
+ getTensorSizesUntilDim (b, loc, extractRoi, 1 );
1747
+ for (unsigned i = 2 ; i < inputRank; i++) {
1748
+ dims.push_back (
1749
+ castIntToIndex (b, loc, outputSizeIntValues[i - 2 ]));
1750
+ }
1751
+ SmallVector<Value> inputSizes;
1752
+ auto inputType = cast<RankedTensorType>(extractRoi.getType ());
1753
+ auto inputRank = inputType.getRank ();
1754
+ for (unsigned i = 2 ; i < inputRank; i++) {
1755
+ Value inputSize = getDimOp (b, loc, extractRoi, i);
1756
+ inputSizes.push_back (b.create <arith::IndexCastOp>(
1757
+ loc, b.getIntegerType (64 ), roiSizeVals[i]));
1758
+ }
1759
+ Value outTensor = b.create <tensor::EmptyOp>(
1760
+ loc, getAsOpFoldResult (dims), inputType.getElementType ());
1761
+ AffineMap idMap = b.getMultiDimIdentityMap (inputRank);
1762
+ SmallVector<utils::IteratorType> iteratorTypes (
1763
+ inputRank, utils::IteratorType::parallel);
1764
+ Value bilinearInterpolatedRoi =
1765
+ b.create <linalg::GenericOp>(
1766
+ loc, outTensor.getType (), ValueRange{}, outTensor,
1767
+ /* indexingMaps=*/ idMap,
1768
+ /* iteratorTypes=*/ iteratorTypes,
1769
+ [&](OpBuilder &b, Location loc, ValueRange args) {
1770
+ Value retVal = bilinearInterpolate (
1771
+ b, op, loc, outputSizeIntValues, extractRoi,
1772
+ inputSizes, ValueRange{}, " bilinear" );
1773
+ b.create <linalg::YieldOp>(loc, retVal);
1774
+ })
1775
+ .getResult (0 );
1776
+
1777
+ // Step 4: Sum pool over interpolated values
1778
+ Value sumPool, paddedInput;
1779
+ SmallVector<Value> kernelSizeIntValues = {oneAttr, oneAttr,
1780
+ scaleH, scaleW};
1781
+ SmallVector<Value, 2 > strideInts = {scaleH, scaleW};
1782
+ SmallVector<Value, 2 > paddingInts = {zeroAttr, zeroAttr};
1783
+ SmallVector<Value, 2 > dilationInts (oneAttr, 2 );
1784
+ SmallVector<Value, 4 > outTensorShape;
1785
+ if (failed (createPoolingOp<linalg::PoolingNchwSumOp>(
1786
+ op, b, self, /* supportNonFPInput=*/ true , false ,
1787
+ /* dimensionality=*/ 2 , kernelSizeIntValues, strideInts,
1788
+ paddingInts, dilationInts,
1789
+ b.getZeroAttr (resultElementType), outTensorShape,
1790
+ paddedInput, sumPool)))
1791
+ return b.notifyMatchFailure (op, " unable to compute sumpool" );
1792
+
1793
+ // Step 5: elementwise division by number of sampling points
1794
+ // to compute avg pool
1795
+ Value outputTensor = b.create <tensor::EmptyOp>(
1796
+ loc, getAsOpFoldResult (outTensorShape), resultElementType);
1797
+ Value divisor = b.create <arith::MulIOp>(loc, scaleH, scaleW);
1798
+ Value avgPool =
1799
+ b.create <linalg::GenericOp>(
1800
+ loc, outputTensor.getType (), sumPool, outputTensor,
1801
+ /* indexingMaps=*/ indexingMapsAvg,
1802
+ /* iteratorTypes=*/ iteratorTypesAvg,
1803
+ [&](OpBuilder &b, Location loc, ValueRange args) {
1804
+ Value avg;
1805
+ if (isa<mlir::IntegerType>(resultElementType))
1806
+ avg = b.create <arith::DivSIOp>(loc, args[0 ],
1807
+ divisor);
1808
+ else if (isa<mlir::FloatType>(resultElementType))
1809
+ avg =
1810
+ b.create <arith::DivFOp>(loc, args[0 ], divisor);
1811
+ b.create <linalg::YieldOp>(loc, avg);
1812
+ })
1813
+ .getResult (0 );
1814
+
1815
+ SmallVector<OpFoldResult> finalStrides (inputRank, oneAttr);
1816
+ SmallVector<OpFoldResult> finalOffsets = {
1817
+ getAsOpFoldResult (iv0), getAsOpFoldResult (iv1), zeroAttr,
1818
+ zeroAttr};
1819
+ SmallVector<OpFoldResult> finalSizes = {
1820
+ oneAttr, oneAttr, getAsOpFoldResult (pooledH),
1821
+ getAsOpFoldResult (pooledW)};
1822
+ SmallVector<OpFoldResult> diagStrides (inputRank, oneAttr);
1823
+ finalOutputTensor = b.create <tensor::InsertSliceOp>(
1824
+ loc, finalOutputTensor, avgPool, finalOffsets, finalSizes,
1825
+ finalStrides);
1826
+ });
1827
+ });
1828
+
1829
+ Type resultType = typeConverter->convertType (op.getType ());
1830
+ b.replaceOp (op, finalOutputTensor);
1831
+ return success ();
1832
+ }
1833
+ };
1834
+ } // namespace
1835
+
1619
1836
void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality (
1620
1837
TypeConverter &typeConverter, RewritePatternSet &patterns,
1621
1838
ConversionTarget &target) {
0 commit comments