Skip to content

Commit e0d6f02

Browse files
Initial commit does not build
1 parent 596b58e commit e0d6f02

File tree

1 file changed

+217
-0
lines changed

1 file changed

+217
-0
lines changed

lib/Conversion/TorchToLinalg/Pooling.cpp

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1616,6 +1616,223 @@ class ConvertAtenAdaptivePoolOp : public OpConversionPattern<OpTy> {
16161616
};
16171617
} // namespace
16181618

1619+
namespace {
1620+
template <typename OpTy, typename PoolingOpTy, int Dim>
1621+
class ConvertRoiAlignOp : public OpConversionPattern<OpTy> {
1622+
public:
1623+
using OpConversionPattern<OpTy>::OpConversionPattern;
1624+
LogicalResult
1625+
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
1626+
ConversionPatternRewriter &rewriter) const override {
1627+
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
1628+
return failure();
1629+
1630+
Location loc = op->getLoc();
1631+
const TypeConverter *typeConverter = this->getTypeConverter();
1632+
Value result = op.getResult();
1633+
1634+
uint64_t pooledHeight =
1635+
cast<ConstantIntOp>(op.getPooledHeight().getDefiningOp()).getValue();
1636+
uint64_t pooledWidth =
1637+
cast<ConstantIntOp>(op.getPooledWidth().getDefiningOp()).getValue();
1638+
uint64_t samplingRatio =
1639+
cast<ConstantIntOp>(op.getSamplingRatio().getDefiningOp()).getValue();
1640+
Value pooledH = op.getPooledHeight();
1641+
Value pooledW = op.getPooledWidth();
1642+
Value spatialScaleVal = op.getSpatialScale();
1643+
llvm::APFloat spatialScale =
1644+
cast<ConstantFloatOp>(op.getSpatialScale().getDefiningOp()).getValue();
1645+
Value rois = op.getRois();
1646+
Value input = op.getInput();
1647+
// RankedTensorType inputType = input.getType();
1648+
Value offset =
1649+
rewriter.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
1650+
Type resultType = cast<RankedTensorType>(result.getType());
1651+
Type resultElementType = resultType.getElementType();
1652+
if (!op.getAligned()) {
1653+
offset = rewriter.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
1654+
}
1655+
1656+
Value lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1657+
Value ub0 = rewriter.create<tensor::DimOp>(loc, rois, 0);
1658+
Value ub1 = rewriter.create<tensor::DimOp>(loc, input, 1);
1659+
Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1660+
SmallVector<Value> finalOutputShape = {ub0, ub1, pooledH, pooledW};
1661+
Value finalOutputTensor = rewriter.create<tensor::EmptyOp>(
1662+
loc, getAsOpFoldResult(finalOutputShape), resultElementType);
1663+
auto forLoop = rewriter.create<scf::ForOp>(
1664+
loc, lb, ub0, step, ValueRange{},
1665+
[&](OpBuilder &b1, Location loc, Value iv0, ValueRange args) {
1666+
auto forLoop = b1.create<scf::ForOp>(
1667+
loc, lb, ub1, step, ValueRange{},
1668+
[&](OpBuilder &b, Location loc, Value iv1, ValueRange args) {
1669+
// Step 1: Extract bounds for region of interest (roi)
1670+
OpFoldResult zeroAttr = b.getI64IntegerAttr(0);
1671+
OpFoldResult oneAttr = b.getI64IntegerAttr(1);
1672+
OpFoldResult twoAttr = b.getI64IntegerAttr(2);
1673+
OpFoldResult threeAttr = b.getI64IntegerAttr(3);
1674+
OpFoldResult fourAttr = b.getI64IntegerAttr(4);
1675+
OpFoldResult fiveAttr = b.getI64IntegerAttr(5);
1676+
// SmallVector<Value> offsetVals{iv0, zeroAttr};
1677+
// SmallVector<OpFoldResult> sizeVals{oneAttr, fiveAttr};
1678+
SmallVector<OpFoldResult> strideVals{oneAttr, oneAttr, oneAttr,
1679+
oneAttr};
1680+
// Value extractRoiBounds = b.create<tensor::ExtractSliceOp>(
1681+
// loc, rois, offsetVals, sizeVals, strideVals);
1682+
Value lowY = b.create<tensor::ExtractOp>(
1683+
loc, rois, ValueRange{iv0, oneAttr});
1684+
Value lowX = b.create<tensor::ExtractOp>(
1685+
loc, rois, ValueRange{iv0, twoAttr});
1686+
Value highY = b.create<tensor::ExtractOp>(
1687+
loc, rois, ValueRange{iv0, threeAttr});
1688+
Value highX = b.create<tensor::ExtractOp>(
1689+
loc, rois, ValueRange{iv0, fourAttr});
1690+
1691+
lowY = b.create<arith::MulFOp>(loc, lowY, spatialScaleVal);
1692+
lowX = b.create<arith::MulFOp>(loc, lowX, spatialScaleVal);
1693+
highY = b.create<arith::MulFOp>(loc, highY, spatialScaleVal);
1694+
highX = b.create<arith::MulFOp>(loc, highX, spatialScaleVal);
1695+
1696+
lowY = b.create<arith::SubFOp>(loc, lowY, offset);
1697+
lowX = b.create<arith::SubFOp>(loc, lowX, offset);
1698+
highY = b.create<arith::SubFOp>(loc, highY, offset);
1699+
highX = b.create<arith::SubFOp>(loc, highX, offset);
1700+
1701+
// Step 2: Extract region of interest using bounds
1702+
Value lowY_int = b.create<math::FloorOp>(loc, lowY);
1703+
Value lowX_int = b.create<math::FloorOp>(loc, lowX);
1704+
Value highY_int = b.create<math::CeilOp>(loc, highY);
1705+
Value highX_int = b.create<math::CeilOp>(loc, highX);
1706+
lowY_int =
1707+
b.create<arith::FPToSIOp>(loc, b.getI64Type(), lowY_int);
1708+
lowX_int =
1709+
b.create<arith::FPToSIOp>(loc, b.getI64Type(), lowX_int);
1710+
highY_int =
1711+
b.create<arith::FPToSIOp>(loc, b.getI64Type(), highY_int);
1712+
highX_int =
1713+
b.create<arith::FPToSIOp>(loc, b.getI64Type(), highX_int);
1714+
1715+
Value roiHeight =
1716+
b.create<arith::SubIOp>(loc, highY_int, lowY_int);
1717+
Value roiWidth =
1718+
b.create<arith::SubIOp>(loc, highX_int, lowX_int);
1719+
1720+
SmallVector<Value> roiOffsetVals{zeroAttr, iv1, lowY_int,
1721+
lowX_int};
1722+
SmallVector<Value> roiSizeVals{oneAttr, oneAttr, roiHeight,
1723+
roiWidth};
1724+
1725+
Value extractRoi = b.create<tensor::ExtractSliceOp>(
1726+
loc, input, roiOffsetVals, roiSizeVals, strideVals);
1727+
1728+
// Step 3: Perform bilinear interpolation over roi
1729+
Value roiBinH = b.create<arith::SubOp>(loc, highY, lowY);
1730+
Value roiBinW = b.create<arith::SubOp>(loc, highX, lowX);
1731+
Value scaleH = b.create<arith::DivOp>(loc, roiBinH, pooledH);
1732+
Value scaleW = b.create<arith::DivOp>(loc, roiBinW, pooledW);
1733+
scaleH = b.create<arith::CeilOp>(loc, scaleH);
1734+
scaleW = b.create<arith::CeilOp>(loc, scaleW);
1735+
scaleH = b.create<arith::FPToSIOp>(loc, b.getI64Type(), scaleH);
1736+
scaleW = b.create<arith::FPToSIOp>(loc, b.getI64Type(), scaleW);
1737+
1738+
Value roiSampleHeight =
1739+
b.create<arith::MulIOp>(loc, pooledH, scaleH);
1740+
Value roiSampleWidth =
1741+
b.create<arith::MulIOp>(loc, pooledW, scaleW);
1742+
1743+
SmallVector<Value> outputSizeIntValues = {roiSampleHeight,
1744+
roiSampleWidth};
1745+
SmallVector<Value> dims =
1746+
getTensorSizesUntilDim(b, loc, extractRoi, 1);
1747+
for (unsigned i = 2; i < inputRank; i++) {
1748+
dims.push_back(
1749+
castIntToIndex(b, loc, outputSizeIntValues[i - 2]));
1750+
}
1751+
SmallVector<Value> inputSizes;
1752+
auto inputType = cast<RankedTensorType>(extractRoi.getType());
1753+
auto inputRank = inputType.getRank();
1754+
for (unsigned i = 2; i < inputRank; i++) {
1755+
Value inputSize = getDimOp(b, loc, extractRoi, i);
1756+
inputSizes.push_back(b.create<arith::IndexCastOp>(
1757+
loc, b.getIntegerType(64), roiSizeVals[i]));
1758+
}
1759+
Value outTensor = b.create<tensor::EmptyOp>(
1760+
loc, getAsOpFoldResult(dims), inputType.getElementType());
1761+
AffineMap idMap = b.getMultiDimIdentityMap(inputRank);
1762+
SmallVector<utils::IteratorType> iteratorTypes(
1763+
inputRank, utils::IteratorType::parallel);
1764+
Value bilinearInterpolatedRoi =
1765+
b.create<linalg::GenericOp>(
1766+
loc, outTensor.getType(), ValueRange{}, outTensor,
1767+
/*indexingMaps=*/idMap,
1768+
/*iteratorTypes=*/iteratorTypes,
1769+
[&](OpBuilder &b, Location loc, ValueRange args) {
1770+
Value retVal = bilinearInterpolate(
1771+
b, op, loc, outputSizeIntValues, extractRoi,
1772+
inputSizes, ValueRange{}, "bilinear");
1773+
b.create<linalg::YieldOp>(loc, retVal);
1774+
})
1775+
.getResult(0);
1776+
1777+
// Step 4: Sum pool over interpolated values
1778+
Value sumPool, paddedInput;
1779+
SmallVector<Value> kernelSizeIntValues = {oneAttr, oneAttr,
1780+
scaleH, scaleW};
1781+
SmallVector<Value, 2> strideInts = {scaleH, scaleW};
1782+
SmallVector<Value, 2> paddingInts = {zeroAttr, zeroAttr};
1783+
SmallVector<Value, 2> dilationInts(oneAttr, 2);
1784+
SmallVector<Value, 4> outTensorShape;
1785+
if (failed(createPoolingOp<linalg::PoolingNchwSumOp>(
1786+
op, b, self, /*supportNonFPInput=*/true, false,
1787+
/*dimensionality=*/2, kernelSizeIntValues, strideInts,
1788+
paddingInts, dilationInts,
1789+
b.getZeroAttr(resultElementType), outTensorShape,
1790+
paddedInput, sumPool)))
1791+
return b.notifyMatchFailure(op, "unable to compute sumpool");
1792+
1793+
// Step 5: elementwise division by number of sampling points
1794+
// to compute avg pool
1795+
Value outputTensor = b.create<tensor::EmptyOp>(
1796+
loc, getAsOpFoldResult(outTensorShape), resultElementType);
1797+
Value divisor = b.create<arith::MulIOp>(loc, scaleH, scaleW);
1798+
Value avgPool =
1799+
b.create<linalg::GenericOp>(
1800+
loc, outputTensor.getType(), sumPool, outputTensor,
1801+
/*indexingMaps=*/indexingMapsAvg,
1802+
/*iteratorTypes=*/iteratorTypesAvg,
1803+
[&](OpBuilder &b, Location loc, ValueRange args) {
1804+
Value avg;
1805+
if (isa<mlir::IntegerType>(resultElementType))
1806+
avg = b.create<arith::DivSIOp>(loc, args[0],
1807+
divisor);
1808+
else if (isa<mlir::FloatType>(resultElementType))
1809+
avg =
1810+
b.create<arith::DivFOp>(loc, args[0], divisor);
1811+
b.create<linalg::YieldOp>(loc, avg);
1812+
})
1813+
.getResult(0);
1814+
1815+
SmallVector<OpFoldResult> finalStrides(inputRank, oneAttr);
1816+
SmallVector<OpFoldResult> finalOffsets = {
1817+
getAsOpFoldResult(iv0), getAsOpFoldResult(iv1), zeroAttr,
1818+
zeroAttr};
1819+
SmallVector<OpFoldResult> finalSizes = {
1820+
oneAttr, oneAttr, getAsOpFoldResult(pooledH),
1821+
getAsOpFoldResult(pooledW)};
1822+
SmallVector<OpFoldResult> diagStrides(inputRank, oneAttr);
1823+
finalOutputTensor = b.create<tensor::InsertSliceOp>(
1824+
loc, finalOutputTensor, avgPool, finalOffsets, finalSizes,
1825+
finalStrides);
1826+
});
1827+
});
1828+
1829+
Type resultType = typeConverter->convertType(op.getType());
1830+
b.replaceOp(op, finalOutputTensor);
1831+
return success();
1832+
}
1833+
};
1834+
} // namespace
1835+
16191836
void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
16201837
TypeConverter &typeConverter, RewritePatternSet &patterns,
16211838
ConversionTarget &target) {

0 commit comments

Comments
 (0)