diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index 3c971354783a..60fbbf675a28 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -18,6 +18,9 @@
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "llvm/Support/Debug.h"
 #include <numeric>
 
 using namespace mlir;
@@ -150,26 +153,29 @@ static LogicalResult createPoolingOp(
     SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr,
     SmallVectorImpl<int64_t> &outTensorShape, Value &paddedInput,
     Value &result) {
   Location loc = op->getLoc();
+
+  Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
   if (!isa<mlir::FloatType>(elementType) && !supportNonFPInput)
     return op->emitError("unimplemented: non-floating point type");
-
+
   Value initValue =
       rewriter.create<arith::ConstantOp>(loc, cast<TypedAttr>(initValueAttr));
   paddedInput = padInputTensor(op, rewriter, self, ceilMode, dimensionality,
                                strideInts, paddingInts, initValue);
-
+
   auto outTensorInitialized = computeOutputTensor(
       op, rewriter, self, dimensionality, ceilMode, strideInts, paddingInts,
       dilationInts, kernelSizeIntValues, outTensorShape, initValue);
-
+
   auto stridesAttr = rewriter.getI64VectorAttr(strideInts);
   auto dilationAttr = rewriter.getI64VectorAttr(dilationInts);
   auto shape = castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues);
+
   Value windowTensor = rewriter.create<tensor::EmptyOp>(
       loc, getAsOpFoldResult(shape), elementType);
-
+
   Value permutedInput = paddedInput, permutedOutput = outTensorInitialized;
   if (dimensionality == 3) {
     // Permute input and output tensor as follows:
@@ -187,7 +193,7 @@ static LogicalResult createPoolingOp(
       return rewriter.notifyMatchFailure(
           op, "failed to perform permutation of tensor");
   }
-
+
   Value poolingResult =
       rewriter
          .create<OpTy>(loc, permutedOutput.getType(),
@@ -1614,6 +1620,455 @@ class ConvertAtenAdaptivePoolOp : public OpConversionPattern {
     return success();
   }
 };
+
+struct ConvertRoiAlignOp final
+    : OpConversionPattern<Torch::TorchvisionRoiAlignOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  static SmallVector<Value>
+  coordinateTransform(OpBuilder &b, Torch::TorchvisionRoiAlignOp op,
+                      Location loc, SmallVector<Value> outputSizes,
+                      Value input, SmallVector<Value> inputSizes,
+                      SmallVector<Value> scaleValues, std::string coordStr,
+                      bool alignCornersBool, SmallVector<Value> indices,
+                      bool clip) {
+    unsigned dimOffset = 2;
+    auto inputType = cast<RankedTensorType>(input.getType());
+    auto inputRank = inputType.getRank();
+
+    Value cstOneFloat =
+        b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
+    Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
+    Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
+
+    SmallVector<Value> proj;
+    for (unsigned i = 0; i < inputRank - dimOffset; i++) {
+      // length_original
+      Value inputFP =
+          b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[i]);
+      // length_resized
+      Value outputSizeFP =
+          b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizes[i]);
+      // scale = length_resized / length_original
+      Value scale;
+      if (alignCornersBool) {
+        // x_original = x_resized * (length_original - 1) / (length_resized - 1)
+        Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
+        Value outputSizeSubOne =
+            b.create<arith::SubFOp>(loc, outputSizeFP, cstOneFloat);
+        Value cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
+                                            outputSizeSubOne, zero);
+        scale = b.create<arith::DivFOp>(loc, inputSubOne, outputSizeSubOne);
+        scale = b.create<arith::SelectOp>(loc, cmp, zero, scale);
+        coordStr = "_align_corners";
+      } else if (scaleValues.empty())
+        scale = b.create<arith::DivFOp>(loc, outputSizeFP, inputFP);
+      else
+        scale = scaleValues[i];
+      // y_resized
+      Value outInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(),
+                                                  indices[i + dimOffset]);
+      Value outFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), outInt);
+      Value preClip;
+      if (coordStr == "_align_corners") {
+        preClip = b.create<arith::MulFOp>(loc, outFP, scale);
+      }
+      if (coordStr == "_asymmetric") {
+        preClip = b.create<arith::DivFOp>(loc, outFP, scale);
+      }
+      if (coordStr == "_pytorch_half_pixel" || coordStr == "" ||
+          coordStr == "_half_pixel_symmetric") {
+        // half-pixel modes
+        // y_resized + 0.5
+        Value outPlusHalf = b.create<arith::AddFOp>(loc, outFP, cstHalf);
+        // (y_resized + 0.5) / scale
+        Value outDivScale = b.create<arith::DivFOp>(loc, outPlusHalf, scale);
+        // _ - 0.5
+        preClip = b.create<arith::SubFOp>(loc, outDivScale, cstHalf);
+      }
+      // for half_pixel_symmetric, need to compute offset from raw scales
+      if (coordStr == "_half_pixel_symmetric" && !scaleValues.empty()) {
+        Value outputSizeFromScale =
+            b.create<arith::MulFOp>(loc, inputFP, scale);
+        Value adjustment =
+            b.create<arith::DivFOp>(loc, outputSizeFP, outputSizeFromScale);
+        Value cstTwo =
+            b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(2.0));
+        Value center = b.create<arith::DivFOp>(loc, inputFP, cstTwo);
+        Value oneMAdjustment =
+            b.create<arith::SubFOp>(loc, cstOneFloat, adjustment);
+        Value offset = b.create<arith::MulFOp>(loc, center, oneMAdjustment);
+        preClip = b.create<arith::AddFOp>(loc, offset, preClip);
+      }
+
+      // for pytorch half pixel, special case for length_resized == 1:
+      if (coordStr == "_pytorch_half_pixel") {
+        Value cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
+                                            outputSizeFP, cstOneFloat);
+        preClip = b.create<arith::SelectOp>(loc, cmp, zero, preClip);
+      }
+      if (clip) {
+        // preClip is the fp position inside the input image to extract from.
+        // clip to [0, inf)
+        Value max = b.create<arith::MaximumFOp>(loc, preClip, zero);
+        Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
+        // clip to [0, length_original - 1].
+        // proj is properly within the input image.
+        proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));
+      } else {
+        proj.push_back(preClip);
+      }
+    }
+
+    return proj;
+  }
+
+  static Value bilinearInterpolate(OpBuilder &b,
+                                   Torch::TorchvisionRoiAlignOp op,
+                                   Location loc,
+                                   SmallVector<Value> outputSizes, Value input,
+                                   SmallVector<Value> inputSizes,
+                                   SmallVector<Value> scaleValues,
+                                   std::string coordStr) {
+    unsigned dimOffset = 2;
+    auto inputType = cast<RankedTensorType>(input.getType());
+    auto inputRank = inputType.getRank();
+
+    Value cstOneFloat =
+        b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
+
+    SmallVector<Value> indices;
+    for (unsigned i = 0; i < inputRank; ++i) {
+      indices.push_back(b.create<linalg::IndexOp>(loc, i));
+    }
+    SmallVector<Value> proj, high, low, highFP, lowFP;
+
+    proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes,
+                               scaleValues, coordStr, false, indices, true);
+    for (unsigned i = 0; i < inputRank - dimOffset; ++i) {
+      // length_original
+      Value inputFP =
+          b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[i]);
+      Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
+
+      // for bilinear interpolation, we look for the nearest indices below and
+      // above proj.
+      lowFP.push_back(b.create<math::FloorOp>(loc, proj[i]));
+      Value projPlusOne = b.create<arith::AddFOp>(loc, cstOneFloat, proj[i]);
+      highFP.push_back(b.create<math::FloorOp>(loc, projPlusOne));
+
+      Value lowInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), lowFP[i]);
+      low.push_back(
+          b.create<arith::IndexCastOp>(loc, b.getIndexType(), lowInt));
+
+      // highFP could be out-of-bounds, so make sure to clip it down before
+      // extracting. If highFP actually gets clipped here, then high[i] will
+      // extract at the last pixel, but will treat it as if it were extracted
+      // from one further position when computing the interpolation weights.
+      Value highExtract =
+          b.create<arith::MinimumFOp>(loc, projPlusOne, inputSubOne);
+      highExtract =
+          b.create<arith::FPToSIOp>(loc, b.getI64Type(), highExtract);
+      high.push_back(
+          b.create<arith::IndexCastOp>(loc, b.getIndexType(), highExtract));
+    }
+
+    indices[dimOffset] = low[0];
+    indices[dimOffset + 1] = low[1];
+    Value p00 = b.create<tensor::ExtractOp>(loc, input, indices);
+
+    indices[dimOffset] = low[0];
+    indices[dimOffset + 1] = high[1];
+    Value p01 = b.create<tensor::ExtractOp>(loc, input, indices);
+
+    indices[dimOffset] = high[0];
+    indices[dimOffset + 1] = low[1];
+    Value p10 = b.create<tensor::ExtractOp>(loc, input, indices);
+
+    indices[dimOffset] = high[0];
+    indices[dimOffset + 1] = high[1];
+    Value p11 = b.create<tensor::ExtractOp>(loc, input, indices);
+
+    // Let Aij := area of rect((yProj, xProj) <-> (y_i*, x_j*)),
+    // where i* = i+1 mod 2 and x_0 = xLow, x_1 = xHigh etc.
+    // We interpolate via the weighted average of pij by weights Aij;
+    // the formula is retval = Sum(pij * Aij for i and j in range(2)).
+    // Note: we do not need to divide by the total rect area == 1.
+
+    // lengths : Aij == dyi * dxj
+    Value dy0 = b.create<arith::SubFOp>(loc, highFP[0], proj[0]);
+    Value dy1 = b.create<arith::SubFOp>(loc, proj[0], lowFP[0]);
+    Value dx0 = b.create<arith::SubFOp>(loc, highFP[1], proj[1]);
+    Value dx1 = b.create<arith::SubFOp>(loc, proj[1], lowFP[1]);
+
+    // left = A00*p00 + A01*p01 = dy0*(dx0*p00 + dx1*p01)
+    Value dx0p00 = b.create<arith::MulFOp>(loc, dx0, p00);
+    Value dx1p01 = b.create<arith::MulFOp>(loc, dx1, p01);
+    Value sum = b.create<arith::AddFOp>(loc, dx0p00, dx1p01);
+    Value left = b.create<arith::MulFOp>(loc, dy0, sum);
+    // right = A10*p10 + A11*p11 = dy1*(dx0*p10 + dx1*p11)
+    Value dx0p10 = b.create<arith::MulFOp>(loc, dx0, p10);
+    Value dx1p11 = b.create<arith::MulFOp>(loc, dx1, p11);
+    sum = b.create<arith::AddFOp>(loc, dx0p10, dx1p11);
+    Value right = b.create<arith::MulFOp>(loc, dy1, sum);
+
+    return b.create<arith::AddFOp>(loc, left, right);
+  }
+
+  LogicalResult
+  matchAndRewrite(Torch::TorchvisionRoiAlignOp op,
+                  typename Torch::TorchvisionRoiAlignOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    Location loc = op->getLoc();
+    Value result = op.getResult();
+
+    uint64_t samplingRatio =
+        cast<Torch::ConstantIntOp>(op.getSamplingRatio().getDefiningOp())
+            .getValue();
+    int64_t samplingRatioInt = static_cast<int64_t>(samplingRatio);
+    Value pooledH = adaptor.getPooledHeight();
+    Value pooledW = adaptor.getPooledWidth();
+    Value pooledHFp =
+        rewriter.create<arith::SIToFPOp>(loc, rewriter.getF32Type(), pooledH);
+    Value pooledWFp =
+        rewriter.create<arith::SIToFPOp>(loc, rewriter.getF32Type(), pooledW);
+
+    // Value spatialScaleVal = adaptor.getSpatialScale();
+    llvm::APFloat spatialScale =
+        cast<Torch::ConstantFloatOp>(op.getSpatialScale().getDefiningOp())
+            .getValue();
+    Value spatialScaleVal = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getF32FloatAttr(spatialScale.convertToDouble()));
+    Value rois = adaptor.getRois();
+    Value input = adaptor.getInput();
+    RankedTensorType inputType = dyn_cast_or_null<RankedTensorType>(
+        this->getTypeConverter()->convertType(op.getInput().getType()));
+    if (inputType == nullptr) {
+      op.emitError("Cannot determine input shape");
+    }
+
+    unsigned inputRank = inputType.getRank();
+    Value offset =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getF32FloatAttr(0.0));
+    RankedTensorType resultType = dyn_cast_or_null<RankedTensorType>(
+        this->getTypeConverter()->convertType(result.getType()));
+    if (resultType == nullptr) {
+      op.emitError("Cannot determine result shape");
+    }
+
+    Type resultElementType = resultType.getElementType();
+    if (!op.getAligned()) {
+      offset = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getF32FloatAttr(0.5));
+    }
+
+    Value lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value ub0 = rewriter.create<tensor::DimOp>(loc, rois, 0);
+    Value ub1 = rewriter.create<tensor::DimOp>(loc, input, 1);
+    Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
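+
+    // The outer scf.for walks the rois and the inner scf.for walks the input
+    // channels; the output tensor is threaded through both loops as an
+    // iter_arg so each (roi, channel) result can be inserted into it.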
+    auto pooledHIdx = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getIndexType(), pooledH);
+    auto pooledWIdx = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getIndexType(), pooledW);
+    SmallVector<Value> finalOutputShape = {ub0, ub1, pooledHIdx, pooledWIdx};
+    Value finalOutputTensor = rewriter.create<tensor::EmptyOp>(
+        loc, getAsOpFoldResult(finalOutputShape), resultElementType);
+    auto resOut = rewriter.create<scf::ForOp>(
+        loc, lb, ub0, step, ValueRange{finalOutputTensor},
+        [&](OpBuilder &b, Location loc, Value iv0, ValueRange args0) {
+          auto res = b.create<scf::ForOp>(
+              loc, lb, ub1, step, ValueRange{args0[0]},
+              [&](OpBuilder &b, Location loc, Value iv1, ValueRange args) {
+                // Step 1: Extract bounds for the region of interest (roi).
+                OpFoldResult zeroAttr = b.getI64IntegerAttr(0);
+                OpFoldResult oneAttr = b.getI64IntegerAttr(1);
+                Value intOne =
+                    b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
+                Value idxZero =
+                    rewriter.create<arith::ConstantIndexOp>(loc, 0);
+                Value idxOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+                Value cstTwo = rewriter.create<arith::ConstantIndexOp>(loc, 2);
+                Value cstThree =
+                    rewriter.create<arith::ConstantIndexOp>(loc, 3);
+                Value cstFour =
+                    rewriter.create<arith::ConstantIndexOp>(loc, 4);
+
+                SmallVector<OpFoldResult> strideVals{oneAttr, oneAttr, oneAttr,
+                                                     oneAttr};
+
+                SmallVector<Value> lowYIndices = {iv0, idxOne};
+                Value lowY = b.create<tensor::ExtractOp>(loc, b.getF32Type(),
+                                                         rois, lowYIndices);
+
+                SmallVector<Value> lowXIndices = {iv0, cstTwo};
+                Value lowX = b.create<tensor::ExtractOp>(loc, b.getF32Type(),
+                                                         rois, lowXIndices);
+
+                SmallVector<Value> highYIndices = {iv0, cstThree};
+                Value highY = b.create<tensor::ExtractOp>(loc, b.getF32Type(),
+                                                          rois, highYIndices);
+                SmallVector<Value> highXIndices = {iv0, cstFour};
+                Value highX = b.create<tensor::ExtractOp>(loc, b.getF32Type(),
+                                                          rois, highXIndices);
+                lowY = b.create<arith::MulFOp>(loc, lowY, spatialScaleVal);
+                lowX = b.create<arith::MulFOp>(loc, lowX, spatialScaleVal);
+                highY = b.create<arith::MulFOp>(loc, highY, spatialScaleVal);
+                highX = b.create<arith::MulFOp>(loc, highX, spatialScaleVal);
+                lowY = b.create<arith::SubFOp>(loc, lowY, offset);
+                lowX = b.create<arith::SubFOp>(loc, lowX, offset);
+                highY = b.create<arith::SubFOp>(loc, highY, offset);
+                highX = b.create<arith::SubFOp>(loc, highX, offset);
+
+                // Step 2: Extract the region of interest using the bounds.
+                Value lowYInt = b.create<math::FloorOp>(loc, lowY);
+                Value lowXInt = b.create<math::FloorOp>(loc, lowX);
+                Value highYInt = b.create<math::CeilOp>(loc, highY);
+                Value highXInt = b.create<math::CeilOp>(loc, highX);
+                lowYInt =
+                    b.create<arith::FPToSIOp>(loc, b.getI64Type(), lowYInt);
+                lowXInt =
+                    b.create<arith::FPToSIOp>(loc, b.getI64Type(), lowXInt);
+                highYInt =
+                    b.create<arith::FPToSIOp>(loc, b.getI64Type(), highYInt);
+                highXInt =
+                    b.create<arith::FPToSIOp>(loc, b.getI64Type(), highXInt);
+                Value lowYIdx = b.create<arith::IndexCastOp>(
+                    loc, b.getIndexType(), lowYInt);
+                Value lowXIdx = b.create<arith::IndexCastOp>(
+                    loc, b.getIndexType(), lowXInt);
+                Value roiHeight =
+                    b.create<arith::SubIOp>(loc, highYInt, lowYInt);
+                Value roiWidth =
+                    b.create<arith::SubIOp>(loc, highXInt, lowXInt);
+                Value roiHIdx = b.create<arith::IndexCastOp>(
+                    loc, b.getIndexType(), roiHeight);
+                Value roiWIdx = b.create<arith::IndexCastOp>(
+                    loc, b.getIndexType(), roiWidth);
+
+                SmallVector<OpFoldResult> roiOffsetVals = {
+                    getAsOpFoldResult(idxZero), getAsOpFoldResult(iv1),
+                    getAsOpFoldResult(lowYInt), getAsOpFoldResult(lowXInt)};
+                SmallVector<Value> roiSizeVals = {intOne, intOne, roiHeight,
+                                                  roiWidth};
+
+                Value extractRoi = b.create<tensor::ExtractSliceOp>(
+                    loc, input, ValueRange{idxZero, iv1, lowYIdx, lowXIdx},
+                    ValueRange{idxOne, idxOne, roiHIdx, roiWIdx},
+                    ValueRange{idxOne, idxOne, idxOne, idxOne});
+
+                // Step 3: Perform bilinear interpolation over the roi.
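+                // The roi is resampled onto a (pooledH * scaleH) x
+                // (pooledW * scaleW) grid: the per-bin sampling factors
+                // scaleH/scaleW come from sampling_ratio when it is > 0 and
+                // are otherwise derived as ceil(roi_extent / pooled_extent).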
+                Value roiBinH = b.create<arith::SubFOp>(loc, highY, lowY);
+                Value roiBinW = b.create<arith::SubFOp>(loc, highX, lowX);
+                Value scaleH =
+                    b.create<arith::DivFOp>(loc, roiBinH, pooledHFp);
+                Value scaleW =
+                    b.create<arith::DivFOp>(loc, roiBinW, pooledWFp);
+                scaleH = b.create<math::CeilOp>(loc, scaleH);
+                scaleW = b.create<math::CeilOp>(loc, scaleW);
+                scaleH =
+                    b.create<arith::FPToSIOp>(loc, b.getI64Type(), scaleH);
+                scaleW =
+                    b.create<arith::FPToSIOp>(loc, b.getI64Type(), scaleW);
+
+                if (samplingRatio > 0) {
+                  scaleH = b.create<arith::ConstantOp>(
+                      loc, rewriter.getI64IntegerAttr(samplingRatio));
+                  scaleW = b.create<arith::ConstantOp>(
+                      loc, rewriter.getI64IntegerAttr(samplingRatio));
+                }
+
+                Value roiSampleHeight =
+                    b.create<arith::MulIOp>(loc, pooledH, scaleH);
+                Value roiSampleWidth =
+                    b.create<arith::MulIOp>(loc, pooledW, scaleW);
+
+                SmallVector<Value> outputSizeIntValues = {roiSampleHeight,
+                                                          roiSampleWidth};
+                SmallVector<Value> dims =
+                    getTensorSizesUntilDim(b, loc, extractRoi, 1);
+
+                for (unsigned i = 2; i < inputRank; ++i) {
+                  auto dim = b.create<arith::IndexCastOp>(
+                      loc, b.getIndexType(), outputSizeIntValues[i - 2]);
+                  dims.push_back(dim);
+                }
+
+                SmallVector<Value> inputSizes;
+                for (unsigned i = 2; i < inputRank; ++i) {
+                  inputSizes.push_back(roiSizeVals[i]);
+                }
+
+                Value outTensor = b.create<tensor::EmptyOp>(
+                    loc, getAsOpFoldResult(dims), inputType.getElementType());
+                auto iteratorTypes = SmallVector<utils::IteratorType>(
+                    inputRank, utils::IteratorType::parallel);
+                SmallVector<AffineMap> idMap(
+                    2, b.getMultiDimIdentityMap(inputRank));
+                Value bilinearInterpolatedRoi =
+                    b.create<linalg::GenericOp>(
+                         loc, outTensor.getType(), extractRoi, outTensor,
+                         /*indexingMaps=*/idMap,
+                         /*iteratorTypes=*/iteratorTypes,
+                         [&](OpBuilder &b, Location loc, ValueRange args) {
+                           Value retVal = bilinearInterpolate(
+                               b, op, loc, outputSizeIntValues, extractRoi,
+                               inputSizes, ValueRange{}, "");
+                           b.create<linalg::YieldOp>(loc, retVal);
+                         })
+                        .getResult(0);
+
+                // Step 4: Sum pool over the interpolated values.
+                Value sumPool, paddedInput;
+
+                SmallVector<Value> kernelSizeIntValues = {/*intOne, intOne,*/
+                                                          scaleH, scaleW};
+                SmallVector<int64_t> strideInts = {samplingRatioInt,
+                                                   samplingRatioInt};
+                SmallVector<int64_t> paddingInts = {0, 0};
+                SmallVector<int64_t> dilationInts = {1, 1};
+                SmallVector<int64_t> outTensorShape;
+                if (failed(createPoolingOp<linalg::PoolingNchwSumOp>(
+                        op, rewriter, bilinearInterpolatedRoi,
+                        /*supportNonFPInput=*/true, /*ceilMode=*/false,
+                        /*dimensionality=*/2, kernelSizeIntValues, strideInts,
+                        paddingInts, dilationInts,
+                        b.getZeroAttr(resultElementType), outTensorShape,
+                        paddedInput, sumPool)))
+                  op.emitError("unable to compute sumpool");
+
+                // Step 5: Elementwise division by the number of sampling
+                // points to compute the avg pool.
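+                // divisor = scaleH * scaleW, i.e. the number of bilinear
+                // samples summed into each output bin.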
+                Value outputTensor = b.create<tensor::EmptyOp>(
+                    loc, getAsOpFoldResult(outTensorShape), resultElementType);
+                Value divisor = b.create<arith::MulIOp>(loc, scaleH, scaleW);
+                divisor = rewriter.create<arith::SIToFPOp>(
+                    loc, rewriter.getF32Type(), divisor);
+                Value avgPool =
+                    b.create<linalg::GenericOp>(
+                         loc, outputTensor.getType(), sumPool, outputTensor,
+                         /*indexingMaps=*/idMap,
+                         /*iteratorTypes=*/iteratorTypes,
+                         [&](OpBuilder &b, Location loc, ValueRange args) {
+                           Value res = b.create<arith::DivFOp>(loc, args[0],
+                                                               divisor);
+                           b.create<linalg::YieldOp>(loc, res);
+                         })
+                        .getResult(0);
+                SmallVector<OpFoldResult> finalStrides(inputRank, oneAttr);
+                SmallVector<OpFoldResult> finalOffsets = {
+                    getAsOpFoldResult(iv0), getAsOpFoldResult(iv1), zeroAttr,
+                    zeroAttr};
+                SmallVector<OpFoldResult> finalSizes = {
+                    idxOne, idxOne, getAsOpFoldResult(pooledHIdx),
+                    getAsOpFoldResult(pooledWIdx)};
+                SmallVector<OpFoldResult> diagStrides(inputRank, oneAttr);
+                auto insert = b.create<tensor::InsertSliceOp>(
+                    loc, avgPool, args[0], finalOffsets, finalSizes,
+                    finalStrides);
+                b.create<scf::YieldOp>(loc, insert.getResult());
+              });
+          b.create<scf::YieldOp>(loc, res.getResult(0));
+        });
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType,
+                                                resOut.getResult(0));
+    return success();
+  }
+};
 } // namespace
 
 void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
@@ -1665,4 +2120,5 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
       typeConverter, context);
   patterns.add<ConvertAtenAdaptivePoolOp<AtenAdaptiveMaxPool3dOp>>(
       typeConverter, context);
+  patterns.add<ConvertRoiAlignOp>(typeConverter, context);
 }