From ff3a8c57f07f5bcec6dad8f2daf9bee7a2fa8958 Mon Sep 17 00:00:00 2001
From: Ivan Garcia
Date: Fri, 7 Feb 2025 17:24:01 -0500
Subject: [PATCH 01/13] Generalize the PyTorch avg_pool linalg lowering
 algorithm for the case where count_include_pad = false.

---
 lib/Conversion/TorchToLinalg/Pooling.cpp      | 407 +++++++++++-------
 projects/pt1/e2e_testing/xfail_sets.py        |  12 +
 .../torch_mlir_e2e_test/test_suite/pooling.py | 108 +++++
 test/Conversion/TorchToLinalg/pooling.mlir    | 130 ++++++
 4 files changed, 506 insertions(+), 151 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index 90b5b2af77a8..288a1d9b5ee6 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -18,6 +18,7 @@
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -829,6 +830,251 @@ class ConvertAtenMaxUnpool3dOp final
 };
 } // namespace
 
+namespace {
+// The following structures and the adsfdasf method
+// are used to get the number of dimensions from the
+// average pooling type at compile time.
+template <typename OpTy> struct AtenAvgPoolTypeNumOfDims {
+  static constexpr int getNumOfDims() { return -1; }
+};
+template <> struct AtenAvgPoolTypeNumOfDims<AtenAvgPool1dOp> {
+  static constexpr int getNumOfDims() { return 1; }
+};
+template <> struct AtenAvgPoolTypeNumOfDims<AtenAvgPool2dOp> {
+  static constexpr int getNumOfDims() { return 2; }
+};
+template <> struct AtenAvgPoolTypeNumOfDims<AtenAvgPool3dOp> {
+  static constexpr int getNumOfDims() { return 3; }
+};
+template <typename OpTy> constexpr int getAvgPoolNumOfDims() {
+  return AtenAvgPoolTypeNumOfDims<OpTy>::getNumOfDims();
+}
+} // namespace
+
+namespace {
+// This structure, used solely in PoolSizeCalculator, provides
+// the intermediate values for each dimension to compute the
+// divisor of the average pooling operator.
+struct PoolSizeValues {
+  int64_t SpatialDimsInt64;
+  int64_t DimSpatialInt;
+  Value InputSpatialDimValues;
+  Value IndexODim;
+  Value ODim;
+  Value DDim;
+  Value PadDim;
+  Value ODimDDim;
+  Value IDim0;
+  Value IDim;
+  Value IDim0KDim;
+  Value IDimPadDim;
+  Value IDim1;
+  Value IDim1IDims0;
+  Value IDim0Clamped;
+  Value IDim1Clamped;
+  Value IDim1_IDim0;
+};
+} // namespace
+
+namespace {
+// This is a helper class to create the pooling size value
+// used in the divisor of the average pooling operator.
+template <int NumOfDims> class PoolSizeCalculator {
+public:
+  PoolSizeCalculator(Value self, Value sumPool,
+                     ConversionPatternRewriter &rewriter, Location loc);
+
+  // The algorithm for computing the divisor with
+  // count_include_pad is manily based on pytorch
+  // implementation. The following code is comment
+  // with pytorch code.
+  // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
+  // Dim below stands for spatial dimension. It replaces the
+  // height and width labels in variables.
+  Value getPoolSize(OpBuilder &b, SmallVectorImpl<Value> &kernelSizeIntValues,
+                    SmallVectorImpl<int64_t> &strideInts,
+                    SmallVectorImpl<int64_t> &paddingInts);
+
+private:
+  PoolSizeValues dims[NumOfDims];
+  ConversionPatternRewriter &rewriterHandle;
+  Location location;
+};
+
+} // namespace
+
+template <int NumOfDims>
+PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
+    Value self, Value sumPool, ConversionPatternRewriter &rewriter,
+    Location loc)
+    : rewriterHandle(rewriter), location(loc) {
+  auto selfType = cast<RankedTensorType>(self.getType());
+  const int64_t selfRank = selfType.getRank();
+  RankedTensorType sumPoolType = cast<RankedTensorType>(sumPool.getType());
+  const int64_t rank = sumPoolType.getRank();
+
+  // Store dimensions in this order:
+  // 0 => width, 1 => height, 2 => depth
+  for (int i = 0; i < NumOfDims; ++i) {
+    dims[i].SpatialDimsInt64 = toPositiveDim(-(i + 1), selfRank);
+    dims[i].InputSpatialDimValues =
+        getDimOp(rewriterHandle, location, self, dims[i].SpatialDimsInt64);
+    dims[i].DimSpatialInt = toPositiveDim(-(i + 1), rank);
+  }
+};
+
+template <int NumOfDims>
+Value PoolSizeCalculator<NumOfDims>::getPoolSize(
+    OpBuilder &b, SmallVectorImpl<Value> &kernelSizeIntValues,
+    SmallVectorImpl<int64_t> &strideInts,
+    SmallVectorImpl<int64_t> &paddingInts) {
+  Value poolSize;
+
+  Value cstZero = rewriterHandle.create<arith::ConstantOp>(
+      location, rewriterHandle.getI64IntegerAttr(0));
+
+  for (int i = 0; i < NumOfDims; ++i) {
+    dims[i].IndexODim =
+        b.create<linalg::IndexOp>(location, /*value=*/dims[i].DimSpatialInt);
+    dims[i].ODim = castIndexToInt64(b, location, dims[i].IndexODim);
+    dims[i].DDim = rewriterHandle.create<arith::ConstantOp>(
+        location, rewriterHandle.getI64IntegerAttr(strideInts[i]));
+    dims[i].PadDim = rewriterHandle.create<arith::ConstantOp>(
+        location, rewriterHandle.getI64IntegerAttr(paddingInts[i]));
+    dims[i].ODimDDim =
+        b.create<arith::MulIOp>(location, dims[i].ODim, dims[i].DDim);
+    dims[i].IDim0 =
+        b.create<arith::SubIOp>(location, dims[i].ODimDDim, dims[i].PadDim);
+    dims[i].IDim = castIndexToInt64(b, location, dims[i].InputSpatialDimValues);
+    dims[i].IDim0KDim = b.create<arith::AddIOp>(location, dims[i].IDim0,
+                                                kernelSizeIntValues[i]);
+    dims[i].IDimPadDim =
+        b.create<arith::AddIOp>(location, dims[i].IDim, dims[i].PadDim);
+    dims[i].IDim1 = b.create<arith::MinSIOp>(location, dims[i].IDim0KDim,
+                                             dims[i].IDimPadDim);
+    dims[i].IDim1IDims0 =
+        b.create<arith::SubIOp>(location, dims[i].IDim1, dims[i].IDim0);
+
+    dims[i].IDim0Clamped =
+        b.create<arith::MaxSIOp>(location, dims[i].IDim0, cstZero);
+    dims[i].IDim1Clamped =
+        b.create<arith::MinSIOp>(location, dims[i].IDim1, dims[i].IDim);
+    dims[i].IDim1_IDim0 = b.create<arith::SubIOp>(
+        location, dims[i].IDim1Clamped, dims[i].IDim0Clamped);
+    if (i == 0) {
+      poolSize = dims[0].IDim1_IDim0;
+    } else {
+      poolSize =
+          b.create<arith::MulIOp>(location, poolSize, dims[i].IDim1_IDim0);
+    }
+  }
+  return poolSize;
+}
+
+// Creates the average pooling operation value when the
+// count_include_pad parameter is equal to false.
+template <typename OpTy>
+static std::optional<LogicalResult> createAvgPoolValueCountIncludePadFalseCase(
+    bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor,
+    ConversionPatternRewriter &rewriter, Value self, Value sumPool,
+    Value outputTensor, Type resultType,
+    SmallVectorImpl<Value> &kernelSizeIntValues,
+    SmallVectorImpl<int64_t> &strideInts, SmallVectorImpl<int64_t> &paddingInts,
+    SmallVector<AffineMap> &indexingMapsAvg,
+    SmallVector<utils::IteratorType> &iteratorTypesAvg) {
+  Location loc = op->getLoc();
+
+  constexpr int avgPoolDims = getAvgPoolNumOfDims<OpTy>();
+
+  bool noPadding = llvm::all_of(paddingInts, [](int64_t p) { return p == 0; });
+  if (countIncludePad || noPadding) {
+    // These cases are not handled here.
+    return std::nullopt;
+  }
+  if (avgPoolDims < 1) {
+    return rewriter.notifyMatchFailure(
+        op, "Unexpected type. Only expected AtenAvgPool1dOp, AtenAvgPool2dOp, "
+            "and AtenAvgPool3dOp.");
+  }
+
+  Type resultElementType = cast<RankedTensorType>(resultType).getElementType();
+
+  PoolSizeCalculator<avgPoolDims> poolSizeCalculator(self, sumPool, rewriter,
+                                                     loc);
+  Value avgPool =
+      rewriter
+          .create<linalg::GenericOp>(
+              loc, outputTensor.getType(), sumPool, outputTensor,
+              /*indexingMaps=*/indexingMapsAvg,
+              /*iteratorTypes=*/iteratorTypesAvg,
+              [&](OpBuilder &b, Location loc, ValueRange args) {
+                auto poolSize = poolSizeCalculator.getPoolSize(
+                    b, kernelSizeIntValues, strideInts, paddingInts);
+                // AtenAvgPool2/3dOp has an optional divisor_override
+                // attribute while AtenAvgPool1dOp does not.
+                if constexpr (avgPoolDims > 1) {
+                  if (!isa<Torch::NoneType>(op.getDivisorOverride().getType()))
+                    poolSize = adaptor.getDivisorOverride();
+                }
+                Value divisor =
+                    convertScalarToDtype(b, loc, poolSize, resultElementType);
+                Value avg;
+                if (isa<mlir::FloatType>(resultElementType))
+                  avg = b.create<arith::DivFOp>(loc, args[0], divisor);
+                else if (isa<mlir::IntegerType>(resultElementType))
+                  avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
+                b.create<linalg::YieldOp>(loc, avg);
+              })
+          .getResult(0);
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool);
+  return success();
+}
+
+// Creates the average pooling operation value when the
+// count_include_pad parameter is equal to true.
+template <typename OpTy>
+static LogicalResult createAvgPoolValueCountIncludePadTrueCase(
+    OpTy op, typename OpTy::Adaptor &adaptor,
+    ConversionPatternRewriter &rewriter, Value self, Value sumPool,
+    Value outputTensor, Type resultType,
+    SmallVectorImpl<Value> &kernelSizeIntValues,
+    SmallVector<AffineMap> &indexingMapsAvg,
+    SmallVector<utils::IteratorType> &iteratorTypesAvg) {
+  Location loc = op->getLoc();
+
+  Type resultElementType = cast<RankedTensorType>(resultType).getElementType();
+
+  Value divisor = kernelSizeIntValues[0];
+  for (uint32_t i = 1; i < kernelSizeIntValues.size(); ++i) {
+    divisor =
+        rewriter.create<arith::MulIOp>(loc, divisor, kernelSizeIntValues[i]);
+  }
+  if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {
+    divisor = isa<Torch::NoneType>(op.getDivisorOverride().getType())
+                  ? divisor
+                  : adaptor.getDivisorOverride();
+  }
+  divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);
+
+  Value avgPool =
+      rewriter
+          .create<linalg::GenericOp>(
+              loc, outputTensor.getType(), sumPool, outputTensor,
+              /*indexingMaps=*/indexingMapsAvg,
+              /*iteratorTypes=*/iteratorTypesAvg,
+              [&](OpBuilder &b, Location loc, ValueRange args) {
+                Value avg;
+                if (isa<mlir::FloatType>(resultElementType))
+                  avg = b.create<arith::DivFOp>(loc, args[0], divisor);
+                else if (isa<mlir::IntegerType>(resultElementType))
+                  avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
+                b.create<linalg::YieldOp>(loc, avg);
+              })
+          .getResult(0);
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool);
+  return success();
+}
+
 namespace {
 template <typename OpTy, typename PoolingOpTy, int Dim>
 class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
@@ -892,159 +1138,18 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
         2, rewriter.getMultiDimIdentityMap(Dim + 2));
     SmallVector<utils::IteratorType> iteratorTypesAvg(
         Dim + 2, utils::IteratorType::parallel);
-    Value avgPool;
-    Value divisor;
-    // Case1: AtenAvgPool1d/2dOp with countIncludePad=false support.
-    if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
-      auto selfType = cast<RankedTensorType>(self.getType());
-      const int64_t selfRank = selfType.getRank();
-      int64_t wDim = toPositiveDim(-1, selfRank);
-      int64_t hDim = toPositiveDim(-2, selfRank);
-      Value inputHeight = getDimOp(rewriter, loc, self, hDim);
-      Value inputWidth = getDimOp(rewriter, loc, self, wDim);
-      RankedTensorType sumPoolType = cast<RankedTensorType>(sumPool.getType());
-      const int64_t rank = sumPoolType.getRank();
-      int dimH = toPositiveDim(-2, rank);
-      int dimW = toPositiveDim(-1, rank);
-      avgPool =
-          rewriter
-              .create<linalg::GenericOp>(
-                  loc, outputTensor.getType(), sumPool, outputTensor,
-                  /*indexingMaps=*/indexingMapsAvg,
-                  /*iteratorTypes=*/iteratorTypesAvg,
-                  [&](OpBuilder &b, Location loc, ValueRange args) {
-                    // The algorithm for computing the divisor with
-                    // count_include_pad is manily based on pytorch
-                    // implementation. The following code is comment
-                    // with pytorch code.
-                    // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
-                    Value indexOh =
-                        b.create<linalg::IndexOp>(loc, /*value=*/dimH);
-                    Value oh = castIndexToInt64(b, loc, indexOh);
-                    Value indexOw =
-                        b.create<linalg::IndexOp>(loc, /*value=*/dimW);
-                    Value ow = castIndexToInt64(b, loc, indexOw);
-
-                    // int64_t ih0 = oh * dH - padH;
-                    Value dH = rewriter.create<arith::ConstantOp>(
-                        loc, rewriter.getI64IntegerAttr(strideInts[0]));
-                    Value padH = rewriter.create<arith::ConstantOp>(
-                        loc, rewriter.getI64IntegerAttr(paddingInts[0]));
-                    Value ohDH = b.create<arith::MulIOp>(loc, oh, dH);
-                    Value ih0 = b.create<arith::SubIOp>(loc, ohDH, padH);
-                    // int64_t iw0 = ow * dW - padW;
-                    Value dW = rewriter.create<arith::ConstantOp>(
-                        loc, rewriter.getI64IntegerAttr(strideInts[1]));
-                    Value padW = rewriter.create<arith::ConstantOp>(
-                        loc, rewriter.getI64IntegerAttr(paddingInts[1]));
-                    Value owDW = b.create<arith::MulIOp>(loc, ow, dW);
-                    Value iw0 = b.create<arith::SubIOp>(loc, owDW, padW);
-                    // int64_t ih1 = std::min(ih0 + kH, input_height + padH);
-                    Value ih = castIndexToInt64(b, loc, inputHeight);
-                    Value ih0KH = b.create<arith::AddIOp>(
-                        loc, ih0, kernelSizeIntValues[0]);
-                    Value ihPadH = b.create<arith::AddIOp>(loc, ih, padH);
-                    Value ih1 = b.create<arith::MinSIOp>(loc, ih0KH, ihPadH);
-                    // int64_t iw1 = std::min(iw0 + kW, input_width + padW);
-                    Value iw = castIndexToInt64(b, loc, inputWidth);
-                    Value iw0KW = b.create<arith::AddIOp>(
-                        loc, iw0, kernelSizeIntValues[1]);
-                    Value iwPadW = b.create<arith::AddIOp>(loc, iw, padW);
-                    Value iw1 = b.create<arith::MinSIOp>(loc, iw0KW, iwPadW);
-                    // int64_t pool_size = (ih1 - ih0) * (iw1 - iw0);
-                    Value ih1Ih0 = b.create<arith::SubIOp>(loc, ih1, ih0);
-                    Value iw1Iw0 = b.create<arith::SubIOp>(loc, iw1, iw0);
-                    Value poolSize =
-                        b.create<arith::MulIOp>(loc, ih1Ih0, iw1Iw0);
-                    // ih0 = std::max(ih0, 0);
-                    Value cstZero = rewriter.create<arith::ConstantOp>(
-                        loc, rewriter.getI64IntegerAttr(0));
-                    Value ih0Clamped =
-                        b.create<arith::MaxSIOp>(loc, ih0, cstZero);
-                    // iw0 = std::max(iw0, 0);
-                    Value iw0Clamped =
-                        b.create<arith::MaxSIOp>(loc, iw0, cstZero);
-                    // ih1 = std::min(ih1, input_height);
-                    Value ih1Clamped = b.create<arith::MinSIOp>(loc, ih1, ih);
-                    // iw1 = std::min(iw1, input_width);
-                    Value iw1Clamped = b.create<arith::MinSIOp>(loc, iw1, iw);
-                    // if (divisor_override.has_value()) {
-                    //   divisor = divisor_override.value();
-                    // } else {
-                    //   if(count_include_pad) {
-                    //     divisor = pool_size;
-                    //   } else {
-                    //     divisor = (ih1 - ih0) * (iw1 - iw0);
-                    //   }
-                    // }
-                    if (countIncludePad) {
-                      divisor = convertScalarToDtype(b, loc, poolSize,
-                                                     resultElementType);
-                    } else {
-                      Value ih1_ih0 =
-                          b.create<arith::SubIOp>(loc, ih1Clamped, ih0Clamped);
-                      Value iw1_iw0 =
-                          b.create<arith::SubIOp>(loc, iw1Clamped, iw0Clamped);
-                      divisor = b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
-                    }
-                    // AtenAvgPool2/3dOp has an optional divisor_override
-                    // attribute while AtenAvgPool1dOp does not.
-                    if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
-                      if (!isa<Torch::NoneType>(
-                              op.getDivisorOverride().getType()))
-                        divisor = adaptor.getDivisorOverride();
-                    }
-
-                    divisor = convertScalarToDtype(b, loc, divisor,
-                                                   resultElementType);
-                    Value avg;
-                    if (isa<mlir::FloatType>(resultElementType))
-                      avg = b.create<arith::DivFOp>(loc, args[0], divisor);
-                    else if (isa<mlir::IntegerType>(resultElementType))
-                      avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
-                    b.create<linalg::YieldOp>(loc, avg);
-                  })
-              .getResult(0);
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool);
-      return success();
-    }
-    // TODO: Add support for count_include_pad equal to `False` in
-    // AtenAvgPool1/3dOp.
-    if (!countIncludePad &&
-        !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) {
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: count_include_pad is expected to be true for "
-              "AtenAvgPool3dOp");
-    }
+    auto divisorOpResult = createAvgPoolValueCountIncludePadFalseCase(
+        countIncludePad, op, adaptor, rewriter, self, sumPool, outputTensor,
+        resultType, kernelSizeIntValues, strideInts, paddingInts,
+        indexingMapsAvg, iteratorTypesAvg);
+    if (divisorOpResult)
+      return *divisorOpResult;
+
+    return createAvgPoolValueCountIncludePadTrueCase(
+        op, adaptor, rewriter, self, sumPool, outputTensor, resultType,
+        kernelSizeIntValues, indexingMapsAvg, iteratorTypesAvg);
 
-    // Case2: AtenAvgPool1/3dOp without count_include_pad equal to `False`.
-    divisor = kernelSizeIntValues[0];
-    for (uint32_t i = 1; i < kernelSizeIntValues.size(); i++) {
-      divisor =
-          rewriter.create<arith::MulIOp>(loc, divisor, kernelSizeIntValues[i]);
-    }
-    if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {
-      divisor = isa<Torch::NoneType>(op.getDivisorOverride().getType())
-                    ? divisor
-                    : adaptor.getDivisorOverride();
-    }
-    divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);
-    avgPool = rewriter
-                  .create<linalg::GenericOp>(
-                      loc, outputTensor.getType(), sumPool, outputTensor,
-                      /*indexingMaps=*/indexingMapsAvg,
-                      /*iteratorTypes=*/iteratorTypesAvg,
-                      [&](OpBuilder &b, Location loc, ValueRange args) {
-                        Value avg;
-                        if (isa<mlir::FloatType>(resultElementType))
-                          avg = b.create<arith::DivFOp>(loc, args[0], divisor);
-                        else if (isa<mlir::IntegerType>(resultElementType))
-                          avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
-                        b.create<linalg::YieldOp>(loc, avg);
-                      })
-                  .getResult(0);
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool);
     return success();
   }
 };
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index e433fabe2712..ccbecab01ffa 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1058,9 +1058,13 @@
     "Aten_CastFloatModule_basic",
     "Aten_CastLongModule_basic",
     "AvgPool1dStaticModule_basic",
+    "AvgPool1dCountIncludePadFalseWithoutPadding_basic",
+    "AvgPool1dCountIncludePadFalse_basic",
     "AvgPool2dStaticModule_basic",
     "AvgPool2dCountIncludePadFalseStaticModule_basic",
     "AvgPool3dStaticModule_basic",
+    "AvgPool3dCountIncludePadFalse_basic",
+    "AvgPool3dCountIncludePadFalseWithoutPadding_basic",
     "BaddbmmBroadcast1DInputModule_basic",
     "BaddbmmBroadcast2DInputModule_basic",
     "BaddbmmStaticModule_basic",
@@ -3386,6 +3390,8 @@
     "AtenKthvalueKeepDimModule_basic",
     "AtenKthvalueModule_basic",
     "AvgPool3dStaticModule_basic",
+    "AvgPool3dCountIncludePadFalse_basic",
+    "AvgPool3dCountIncludePadFalseWithoutPadding_basic",
     "Conv_Transpose1dModule_basic",
     "Conv_Transpose1dStaticModule_basic",
    "Conv_Transpose2dStaticModule_basic",
@@ -3464,6 +3470,8 @@
     "AvgPool1dFloatModule_basic",
     "AvgPool1dIntModule_basic",
     "AvgPool1dStaticModule_basic",
+    "AvgPool1dCountIncludePadFalseWithoutPadding_basic",
+    "AvgPool1dCountIncludePadFalse_basic",
     "AvgPool2dCeilModeTrueModule_basic",
     "AvgPool2dDivisorOverrideModule_basic",
     "AvgPool2dFloatModule_basic",
@@ -3861,6 +3869,8 @@
     "AtenKthvalueModule_basic",
     "AvgPool2dCountIncludePadFalseStaticModule_basic",
     "AvgPool3dStaticModule_basic",
+    "AvgPool3dCountIncludePadFalse_basic",
+    "AvgPool3dCountIncludePadFalseWithoutPadding_basic",
     "Conv_Transpose1dModule_basic",
     "Conv_Transpose1dStaticModule_basic",
     "Conv_Transpose2dStaticModule_basic",
@@ -4014,6 +4024,8 @@
     "AvgPool1dFloatModule_basic",
     "AvgPool1dIntModule_basic",
     "AvgPool1dStaticModule_basic",
+    "AvgPool1dCountIncludePadFalseWithoutPadding_basic",
+    "AvgPool1dCountIncludePadFalse_basic",
     "AvgPool2dCeilModeTrueModule_basic",
     "AvgPool2dDivisorOverrideModule_basic",
     "AvgPool2dFloatModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
index e2eaa4cfd0fe..ead29dcb84f2 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -1460,6 +1460,66 @@ def AvgPool3dStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 2, 4, 4, 4, low=-1))
 
 
+class AvgPool3dCountIncludePadFalse(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.ap2d = torch.nn.AvgPool3d(
+            kernel_size=[3, 3, 3],
+            stride=[1, 1, 1],
+            padding=[1, 1, 1],
+            ceil_mode=False,
+            count_include_pad=False,
+            divisor_override=None,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3, 3, 12, 12, 12], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap2d(x)
+
+
+@register_test_case(module_factory=lambda: AvgPool3dCountIncludePadFalse())
+def AvgPool3dCountIncludePadFalse_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 3, 12, 12, 12, low=-1))
+
+
+class AvgPool3dCountIncludePadFalseWithoutPadding(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.ap2d = torch.nn.AvgPool3d(
+            kernel_size=[3, 3, 3],
+            stride=[1, 1, 1],
+            padding=[0, 0, 0],
+            ceil_mode=False,
+            count_include_pad=False,
+            divisor_override=None,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3, 3, 12, 12, 12], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AvgPool3dCountIncludePadFalseWithoutPadding()
+)
+def AvgPool3dCountIncludePadFalseWithoutPadding_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 3, 12, 12, 12, low=-1))
+
+
 # ==============================================================================
 
 
@@ -1532,6 +1592,54 @@ def AvgPool1dStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(2, 4, 20, high=100))
 
 
+class AvgPool1dCountIncludePadFalseWithoutPadding(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ap1d = torch.nn.AvgPool1d(
+            kernel_size=3, stride=1, padding=0, ceil_mode=False, count_include_pad=False
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap1d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AvgPool1dCountIncludePadFalseWithoutPadding()
+)
+def AvgPool1dCountIncludePadFalseWithoutPadding_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 20))
+
+
+class AvgPool1dCountIncludePadFalse(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ap1d = torch.nn.AvgPool1d(
+            kernel_size=3, stride=1, padding=1, ceil_mode=False, count_include_pad=False
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap1d(x)
+
+
+@register_test_case(module_factory=lambda: AvgPool1dCountIncludePadFalse())
+def AvgPool1dCountIncludePadFalse_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 20))
+
+
 # ==============================================================================
diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir
index 558c50c4f08f..53faa1d37d4f 100644
--- a/test/Conversion/TorchToLinalg/pooling.mlir
+++ b/test/Conversion/TorchToLinalg/pooling.mlir
@@ -95,3 +95,133 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?,?],f32> {
   // CHECK: } -> tensor<?x?x?x?x?xf32>
   return %4 : !torch.vtensor<[?,?,?,?,?],f32>
 }
+
+// CHECK-LABEL: func @forward_avg_pool2d
+func.func @forward_avg_pool2d(%arg0: !torch.vtensor<[1,3,64,56],f32>) -> !torch.vtensor<[1,3, 61,27],f32> {
+  // CHECK: linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<[1, 2]> : vector<2xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x3x64x58xf32>, tensor<4x5xf32>) outs(%[[OUT1:.*]] : tensor<1x3x61x27xf32>) -> tensor<1x3x61x27xf32>
+  // CHECK: linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x61x27xf32>) outs(%[[OUT2:.*]] : tensor<1x3x61x27xf32>)
+  // CHECK-NEXT: ^bb0(%[[BIIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
+  // CHECK-NEXT: %[[TMP1:.*]] = arith.divf %[[BIIN1:.*]], %[[CONST1:.*]] : f32
+  // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32
+  // CHECK-NEXT: } -> tensor<1x3x61x27xf32>
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %int5 = torch.constant.int 5
+  %0 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %true, %none : !torch.vtensor<[1,3,64,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,61,27],f32>
+  return %3 : !torch.vtensor<[1,3,61,27],f32>
+}
+
+// CHECK-LABEL: func @forward_avg_pool2d_countincludepad_false
+func.func @forward_avg_pool2d_countincludepad_false(%arg0: !torch.vtensor<[1,3,64,56],f32>) -> !torch.vtensor<[1,3, 61,27],f32> {
+  // CHECK: linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<[1, 2]> : vector<2xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x3x64x58xf32>, tensor<4x5xf32>) outs(%[[OUT1:.*]] : tensor<1x3x61x27xf32>) -> tensor<1x3x61x27xf32>
+  // CHECK: linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x61x27xf32>) outs(%[[OUT2:.*]] : tensor<1x3x61x27xf32>)
+  // CHECK-NEXT: ^bb0(%[[BIIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
+  // CHECK-COUNT-4: arith.minsi
+  // CHECK-COUNT-1: arith.divf
+  // CHECK: linalg.yield %[[TMP1:.*]] : f32
+  // CHECK-NEXT: } -> tensor<1x3x61x27xf32>
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %int5 = torch.constant.int 5
+  %0 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %false, %none : !torch.vtensor<[1,3,64,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,61,27],f32>
+  return %3 : !torch.vtensor<[1,3,61,27],f32>
+}
+
+// CHECK-LABEL: func @forward_avg_pool3d
+func.func @forward_avg_pool3d(%arg0: !torch.vtensor<[1,3,7,64,56],f32>) -> !torch.vtensor<[1,3,4,31,54],f32> {
+  // CHECK: linalg.pooling_ndhwc_sum {dilations = dense<1> : vector<3xi64>, strides = dense<[1, 2, 1]> : vector<3xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x7x66x58x3xf32>, tensor<4x5x5xf32>) outs(%[[OUT1:.*]] : tensor<1x4x31x54x3xf32>) -> tensor<1x4x31x54x3xf32>
+  // CHECK: linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x4x31x54xf32>) outs(%[[OUT2:.*]] : tensor<1x3x4x31x54xf32>)
+  // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
+  // CHECK-NEXT: %[[TMP1:.*]] = arith.divf %[[BIN1:.*]], %[[CONST1:.*]] : f32
+  // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32
+  // CHECK-NEXT: } -> tensor<1x3x4x31x54xf32>
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %int5 = torch.constant.int 5
+  %0 = torch.prim.ListConstruct %int4, %int5, %int5 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int2, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool3d %arg0, %0, %1, %2, %false, %true, %none : !torch.vtensor<[1,3,7,64,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,4,31,54],f32>
+  return %3 : !torch.vtensor<[1,3,4,31,54],f32>
+}
+
+// CHECK-LABEL: func @forward_avg_pool3dd_countincludepad_false
+func.func @forward_avg_pool3dd_countincludepad_false(%arg0: !torch.vtensor<[1,3,7,64,56],f32>) -> !torch.vtensor<[1,3,4,31,54],f32> {
+  // CHECK: linalg.pooling_ndhwc_sum {dilations = dense<1> : vector<3xi64>, strides = dense<[1, 2, 1]> : vector<3xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x7x66x58x3xf32>, tensor<4x5x5xf32>) outs(%[[OUT1:.*]] : tensor<1x4x31x54x3xf32>) -> tensor<1x4x31x54x3xf32>
+  // CHECK: linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x4x31x54xf32>) outs(%[[OUT2:.*]] : tensor<1x3x4x31x54xf32>)
+  // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
+  // CHECK-COUNT-6: arith.minsi
+  // CHECK-COUNT-1: arith.divf
+  // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32
+  // CHECK-NEXT: } -> tensor<1x3x4x31x54xf32>
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %int5 = torch.constant.int 5
+  %0 = torch.prim.ListConstruct %int4, %int5, %int5 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int2, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool3d %arg0, %0, %1, %2, %false, %false, %none : !torch.vtensor<[1,3,7,64,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,4,31,54],f32>
+  return %3 : !torch.vtensor<[1,3,4,31,54],f32>
+}
+
+// CHECK-LABEL: func @forward_avg_pool1d
+func.func @forward_avg_pool1d(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,12],f32> {
+  // CHECK: linalg.pooling_ncw_sum {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} ins(%[[IN1:.*]], %[[IN2:.*]] : tensor<1x512x12xf32>, tensor<1xf32>) outs(%[[OUT1:.*]] : tensor<1x512x12xf32>) -> tensor<1x512x12xf32>
+  // CHECK: linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[IN3:.*]] : tensor<1x512x12xf32>) outs(%[[OUT2:.*]] : tensor<1x512x12xf32>
+  // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
+  // CHECK-NEXT: %[[TMP1:.*]] = arith.divf %[[BIN1:.*]], %[[CONST1:.*]] : f32
+  // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32
+  // CHECK-NEXT: } -> tensor<1x512x12xf32>
+  %int1 = torch.constant.int 1
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %true : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,12],f32>
+  return %3 : !torch.vtensor<[1,512,12],f32>
+}
+
+// CHECK-LABEL: func @forward_avg_pool1d_countincludepad_false
+func.func @forward_avg_pool1d_countincludepad_false(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,12],f32> {
+  // CHECK: linalg.pooling_ncw_sum {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} ins(%[[IN1:.*]], %[[IN2:.*]] : tensor<1x512x12xf32>, tensor<1xf32>) outs(%[[OUT1:.*]] : tensor<1x512x12xf32>) -> tensor<1x512x12xf32>
+  // CHECK: linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[IN3:.*]] : tensor<1x512x12xf32>) outs(%[[OUT2:.*]] : tensor<1x512x12xf32>
+  // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
+  // CHECK-COUNT-2: arith.minsi
+  // CHECK-COUNT-1: arith.divf
+  // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32
+  // CHECK-NEXT: } -> tensor<1x512x12xf32>
+  %int1 = torch.constant.int 1
+  %false = torch.constant.bool false
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %false : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,12],f32>
+  return %3 : !torch.vtensor<[1,512,12],f32>
+}

From 20ff9d35e60ad4cdbedf9e91e431bb20dff1247e Mon Sep 17 00:00:00 2001
From: Ivan Garcia
Date: Fri, 7 Feb 2025 20:04:22 -0500
Subject: [PATCH 02/13] Remove semicolon accidentally left in a function
 definition.
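The stray ';' followed the body of the out-of-line constructor definition
added in the previous patch. A function definition at namespace scope needs
no terminating semicolon; the extra one is an empty declaration that some
compilers flag (e.g. clang's -Wextra-semi). Reduced sketch of the pattern
being fixed (signature quoted from the patch, body elided):

  template <int NumOfDims>
  PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
      Value self, Value sumPool, ConversionPatternRewriter &rewriter,
      Location loc)
      : rewriterHandle(rewriter), location(loc) {
    // ... member initialization ...
  };  // <- this trailing ';' is what the hunk below removes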
---
 lib/Conversion/TorchToLinalg/Pooling.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index 288a1d9b5ee6..6f8ca853e6fc 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -921,7 +921,7 @@ PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
         getDimOp(rewriterHandle, location, self, dims[i].SpatialDimsInt64);
     dims[i].DimSpatialInt = toPositiveDim(-(i + 1), rank);
   }
-};
+}
 
 template <int NumOfDims>
 Value PoolSizeCalculator<NumOfDims>::getPoolSize(

From aafda93085400be562924b8413ef3ffffc5186a5 Mon Sep 17 00:00:00 2001
From: Ivan Garcia
Date: Tue, 11 Feb 2025 09:48:16 -0500
Subject: [PATCH 03/13] Remove dynamic dimensions in e2e test; they were not
 really needed for this test.

---
 projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
index ead29dcb84f2..e308e38bc1e3 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -1603,7 +1603,7 @@ def __init__(self):
     @annotate_args(
         [
             None,
-            ([-1, -1, -1], torch.float32, True),
+            ([3, 4, 20], torch.float32, True),
         ]
     )
     def forward(self, x):
@@ -1628,7 +1628,7 @@ def __init__(self):
     @annotate_args(
         [
             None,
-            ([-1, -1, -1], torch.float32, True),
+            ([3, 4, 20], torch.float32, True),
         ]
     )
     def forward(self, x):

From c61bf8e6f6b40e6cf354ae7936af9177a78aa5f0 Mon Sep 17 00:00:00 2001
From: Ivan Garcia
Date: Wed, 12 Feb 2025 16:26:45 -0500
Subject: [PATCH 04/13] Fix xfail_sets.py file - filtering too much.

---
 projects/pt1/e2e_testing/xfail_sets.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index ccbecab01ffa..e6d24c9d5908 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1058,13 +1058,9 @@
     "Aten_CastFloatModule_basic",
     "Aten_CastLongModule_basic",
     "AvgPool1dStaticModule_basic",
-    "AvgPool1dCountIncludePadFalseWithoutPadding_basic",
-    "AvgPool1dCountIncludePadFalse_basic",
     "AvgPool2dStaticModule_basic",
     "AvgPool2dCountIncludePadFalseStaticModule_basic",
     "AvgPool3dStaticModule_basic",
-    "AvgPool3dCountIncludePadFalse_basic",
-    "AvgPool3dCountIncludePadFalseWithoutPadding_basic",
     "BaddbmmBroadcast1DInputModule_basic",
     "BaddbmmBroadcast2DInputModule_basic",
     "BaddbmmStaticModule_basic",
@@ -3869,8 +3865,6 @@
     "AtenKthvalueModule_basic",
     "AvgPool2dCountIncludePadFalseStaticModule_basic",
     "AvgPool3dStaticModule_basic",
-    "AvgPool3dCountIncludePadFalse_basic",
-    "AvgPool3dCountIncludePadFalseWithoutPadding_basic",
     "Conv_Transpose1dModule_basic",
     "Conv_Transpose1dStaticModule_basic",
     "Conv_Transpose2dStaticModule_basic",
@@ -4024,8 +4018,6 @@
     "AvgPool1dFloatModule_basic",
     "AvgPool1dIntModule_basic",
     "AvgPool1dStaticModule_basic",
-    "AvgPool1dCountIncludePadFalseWithoutPadding_basic",
-    "AvgPool1dCountIncludePadFalse_basic",
     "AvgPool2dCeilModeTrueModule_basic",
     "AvgPool2dDivisorOverrideModule_basic",
     "AvgPool2dFloatModule_basic",

From d06657b32c7dd628eaf65098545b3e23eda95bd2 Mon Sep 17 00:00:00 2001
From: Ivan Garcia
Date: Thu, 13 Feb 2025 06:52:13 -0500
Subject: [PATCH 05/13] One more fix to xfail_sets.py

---
 projects/pt1/e2e_testing/xfail_sets.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index e6d24c9d5908..3f8b58e6be62 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1989,7 +1989,11 @@
     "MeshgridIndexingIJ_basic",
     "MeshgridIndexingXY_basic",
     "Meshgrid_basic",
+    "AvgPool1dCountIncludePadFalseWithoutPadding_basic",
+    "AvgPool1dCountIncludePadFalse_basic",
     "AvgPool2dCountIncludePadFalseStaticModule_basic",
+    "AvgPool3dCountIncludePadFalseWithoutPadding_basic",
+    "AvgPool3dCountIncludePadFalse_basic",
     "TensorSplitSections_GetItemModule_basic",
     "TensorSplitSections_ListUnpackModule_basic",
     "Atleast1dModule0dInput_basic",
@@ -3466,8 +3470,8 @@
     "AvgPool1dFloatModule_basic",
     "AvgPool1dIntModule_basic",
     "AvgPool1dStaticModule_basic",
-    "AvgPool1dCountIncludePadFalseWithoutPadding_basic",
-    "AvgPool1dCountIncludePadFalse_basic",
+    # "AvgPool1dCountIncludePadFalseWithoutPadding_basic",
+    # "AvgPool1dCountIncludePadFalse_basic",
     "AvgPool2dCeilModeTrueModule_basic",
     "AvgPool2dDivisorOverrideModule_basic",
     "AvgPool2dFloatModule_basic",

From cfae3041210344d876b2b7b250b7e79687b8d6d3 Mon Sep 17 00:00:00 2001
From: Ivan Garcia
Date: Thu, 13 Feb 2025 07:20:47 -0500
Subject: [PATCH 06/13] Forgot to remove commented code.

---
 projects/pt1/e2e_testing/xfail_sets.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 3f8b58e6be62..5fe29994f032 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -3470,8 +3470,6 @@
     "AvgPool1dFloatModule_basic",
     "AvgPool1dIntModule_basic",
     "AvgPool1dStaticModule_basic",
-    # "AvgPool1dCountIncludePadFalseWithoutPadding_basic",
-    # "AvgPool1dCountIncludePadFalse_basic",
     "AvgPool2dCeilModeTrueModule_basic",
     "AvgPool2dDivisorOverrideModule_basic",
     "AvgPool2dFloatModule_basic",

From a98a5bfa4a25d21d0a9c99fabb6354d0358fb7d3 Mon Sep 17 00:00:00 2001
From: Ivan Garcia
Date: Fri, 14 Feb 2025 12:16:35 -0500
Subject: [PATCH 07/13] Addressing feedback from Sayan (sahas3).

---
 lib/Conversion/TorchToLinalg/Pooling.cpp      | 115 ++++++++----------
 .../torch_mlir_e2e_test/test_suite/pooling.py |   8 +-
 2 files changed, 54 insertions(+), 69 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index 6f8ca853e6fc..0dcd4b6317dd 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -831,7 +831,7 @@ class ConvertAtenMaxUnpool3dOp final
 } // namespace
 
 namespace {
-// The following structures and the adsfdasf method
+// The following structures and the getNumOfDims method
 // are used to get the number of dimensions from the
 // average pooling type at compile time.
 template <typename OpTy> struct AtenAvgPoolTypeNumOfDims {
@@ -851,31 +851,6 @@ template <typename OpTy> constexpr int getAvgPoolNumOfDims() {
 }
 } // namespace
 
-namespace {
-// This structure, used solely in PoolSizeCalculator, provides
-// the intermediate values for each dimension to compute the
-// divisor of the average pooling operator.
-struct PoolSizeValues {
-  int64_t SpatialDimsInt64;
-  int64_t DimSpatialInt;
-  Value InputSpatialDimValues;
-  Value IndexODim;
-  Value ODim;
-  Value DDim;
-  Value PadDim;
-  Value ODimDDim;
-  Value IDim0;
-  Value IDim;
-  Value IDim0KDim;
-  Value IDimPadDim;
-  Value IDim1;
-  Value IDim1IDims0;
-  Value IDim0Clamped;
-  Value IDim1Clamped;
-  Value IDim1_IDim0;
-};
-} // namespace
-
 namespace {
 // This is a helper class to create the pooling size value
 // used in the divisor of the average pooling operator.
@@ -885,7 +860,7 @@ template <int NumOfDims> class PoolSizeCalculator {
                      ConversionPatternRewriter &rewriter, Location loc);
 
   // The algorithm for computing the divisor with
-  // count_include_pad is manily based on pytorch
+  // count_include_pad equal is mainly based on pytorch
   // implementation. The following code is comment
  // with pytorch code.
   // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
@@ -896,8 +871,8 @@ template <int NumOfDims> class PoolSizeCalculator {
                     SmallVectorImpl<int64_t> &paddingInts);
 
 private:
-  PoolSizeValues dims[NumOfDims];
-  ConversionPatternRewriter &rewriterHandle;
+  int64_t DimSizeFromSumPoolType[NumOfDims];
+  Value InputSpatialDimValues[NumOfDims];
   Location location;
 };
 
@@ -907,7 +882,7 @@ template <int NumOfDims>
 PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
     Value self, Value sumPool, ConversionPatternRewriter &rewriter,
     Location loc)
-    : rewriterHandle(rewriter), location(loc) {
+    : location(loc) {
   auto selfType = cast<RankedTensorType>(self.getType());
   const int64_t selfRank = selfType.getRank();
   RankedTensorType sumPoolType = cast<RankedTensorType>(sumPool.getType());
@@ -916,10 +891,10 @@ PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
   // Store dimensions in this order:
   // 0 => width, 1 => height, 2 => depth
   for (int i = 0; i < NumOfDims; ++i) {
-    dims[i].SpatialDimsInt64 = toPositiveDim(-(i + 1), selfRank);
-    dims[i].InputSpatialDimValues =
-        getDimOp(rewriterHandle, location, self, dims[i].SpatialDimsInt64);
-    dims[i].DimSpatialInt = toPositiveDim(-(i + 1), rank);
+    int64_t DimSizeFromSelfType = toPositiveDim(-(i + 1), selfRank);
+    InputSpatialDimValues[i] =
+        getDimOp(rewriter, location, self, DimSizeFromSelfType);
+    DimSizeFromSumPoolType[i] = toPositiveDim(-(i + 1), rank);
   }
 }
 
@@ -930,42 +905,52 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(
     SmallVectorImpl<int64_t> &paddingInts) {
   Value poolSize;
 
-  Value cstZero = rewriterHandle.create<arith::ConstantOp>(
-      location, rewriterHandle.getI64IntegerAttr(0));
+  Value cstZero = b.create<arith::ConstantOp>(location, b.getI64IntegerAttr(0));
 
   for (int i = 0; i < NumOfDims; ++i) {
-    dims[i].IndexODim =
-        b.create<linalg::IndexOp>(location, /*value=*/dims[i].DimSpatialInt);
-    dims[i].ODim = castIndexToInt64(b, location, dims[i].IndexODim);
-    dims[i].DDim = rewriterHandle.create<arith::ConstantOp>(
-        location, rewriterHandle.getI64IntegerAttr(strideInts[i]));
-    dims[i].PadDim = rewriterHandle.create<arith::ConstantOp>(
-        location, rewriterHandle.getI64IntegerAttr(paddingInts[i]));
-    dims[i].ODimDDim =
-        b.create<arith::MulIOp>(location, dims[i].ODim, dims[i].DDim);
-    dims[i].IDim0 =
-        b.create<arith::SubIOp>(location, dims[i].ODimDDim, dims[i].PadDim);
-    dims[i].IDim = castIndexToInt64(b, location, dims[i].InputSpatialDimValues);
-    dims[i].IDim0KDim = b.create<arith::AddIOp>(location, dims[i].IDim0,
-                                                kernelSizeIntValues[i]);
-    dims[i].IDimPadDim =
-        b.create<arith::AddIOp>(location, dims[i].IDim, dims[i].PadDim);
-    dims[i].IDim1 = b.create<arith::MinSIOp>(location, dims[i].IDim0KDim,
-                                             dims[i].IDimPadDim);
-    dims[i].IDim1IDims0 =
-        b.create<arith::SubIOp>(location, dims[i].IDim1, dims[i].IDim0);
-
-    dims[i].IDim0Clamped =
-        b.create<arith::MaxSIOp>(location, dims[i].IDim0, cstZero);
-    dims[i].IDim1Clamped =
-        b.create<arith::MinSIOp>(location, dims[i].IDim1, dims[i].IDim);
-    dims[i].IDim1_IDim0 = b.create<arith::SubIOp>(
-        location, dims[i].IDim1Clamped, dims[i].IDim0Clamped);
+    // See the link below for the PyTorch implementation where
+    // this is derived from:
+    // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
+    // Dim below stands for spatial dimension. Prior to the February
+    // 2025 change, these variables used "height" and "width" (or
+    // "h" and "w") in these intermediate variables instead of "Dim".
+    Value IndexODim;
+    Value ODim;
+    Value DDim;
+    Value PadDim;
+    Value ODimDDim;
+    Value IDim0;
+    Value IDim;
+    Value IDim0KDim;
+    Value IDimPadDim;
+    Value IDim1;
+    Value IDim0Clamped;
+    Value IDim1Clamped;
+    Value IDim1_IDim0_Clamped;
+    IndexODim = b.create<linalg::IndexOp>(location,
+                                          /*value=*/DimSizeFromSumPoolType[i]);
+    ODim = castIndexToInt64(b, location, IndexODim);
+    DDim = b.create<arith::ConstantOp>(location,
+                                       b.getI64IntegerAttr(strideInts[i]));
+    PadDim = b.create<arith::ConstantOp>(location,
+                                         b.getI64IntegerAttr(paddingInts[i]));
+    ODimDDim = b.create<arith::MulIOp>(location, ODim, DDim);
+    IDim0 = b.create<arith::SubIOp>(location, ODimDDim, PadDim);
+    IDim = castIndexToInt64(b, location, InputSpatialDimValues[i]);
+    IDim0KDim =
+        b.create<arith::AddIOp>(location, IDim0, kernelSizeIntValues[i]);
+    IDimPadDim = b.create<arith::AddIOp>(location, IDim, PadDim);
+    IDim1 = b.create<arith::MinSIOp>(location, IDim0KDim, IDimPadDim);
+
+    IDim0Clamped = b.create<arith::MaxSIOp>(location, IDim0, cstZero);
+    IDim1Clamped = b.create<arith::MinSIOp>(location, IDim1, IDim);
+    IDim1_IDim0_Clamped =
+        b.create<arith::SubIOp>(location, IDim1Clamped, IDim0Clamped);
     if (i == 0) {
-      poolSize = dims[0].IDim1_IDim0;
+      poolSize = IDim1_IDim0_Clamped;
     } else {
       poolSize =
-          b.create<arith::MulIOp>(location, poolSize, dims[i].IDim1_IDim0);
+          b.create<arith::MulIOp>(location, poolSize, IDim1_IDim0_Clamped);
     }
   }
   return poolSize;
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
index e308e38bc1e3..ce7d2e2bc42e 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -1464,7 +1464,7 @@ class AvgPool3dCountIncludePadFalse(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
-        self.ap2d = torch.nn.AvgPool3d(
+        self.ap3d = torch.nn.AvgPool3d(
             kernel_size=[3, 3, 3],
             stride=[1, 1, 1],
             padding=[1, 1, 1],
@@ -1481,7 +1481,7 @@ def __init__(self):
         ]
     )
     def forward(self, x):
-        return self.ap2d(x)
+        return self.ap3d(x)
 
 
 @register_test_case(module_factory=lambda: AvgPool3dCountIncludePadFalse())
@@ -1493,7 +1493,7 @@ class AvgPool3dCountIncludePadFalseWithoutPadding(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
-        self.ap2d = torch.nn.AvgPool3d(
+        self.ap3d = torch.nn.AvgPool3d(
             kernel_size=[3, 3, 3],
             stride=[1, 1, 1],
             padding=[0, 0, 0],
@@ -1510,7 +1510,7 @@ def __init__(self):
         ]
     )
     def forward(self, x):
-        return self.ap2d(x)
+        return self.ap3d(x)
 
 
 @register_test_case(

From 8cb977c87ac1e0ba53308cedd9a8eb1dd15be00d Mon Sep 17 00:00:00 2001
From: Ivan Garcia
Date: Fri, 14 Feb 2025 16:07:25 -0500
Subject: [PATCH 08/13] Inlining local variable declarations with definitions.
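Each intermediate Value in PoolSizeCalculator::getPoolSize was declared at
the top of the loop body and assigned afterwards. This change merges each
declaration into its definition at the point of first use; the shape of the
change, quoting one variable from the diff below:

  // Before:
  Value ODim;
  ODim = castIndexToInt64(b, location, IndexODim);

  // After:
  Value ODim = castIndexToInt64(b, location, IndexODim);

No functional change; the Values are simply initialized where they become
meaningful.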
---
 lib/Conversion/TorchToLinalg/Pooling.cpp | 46 +++++++++---------
 1 file changed, 17 insertions(+), 29 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index 0dcd4b6317dd..f3514961e285 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -914,37 +914,25 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(
     // Dim below stands for spatial dimension. Prior to the February
     // 2025 change, these variables used "height" and "width" (or
     // "h" and "w") in these intermediate variables instead of "Dim".
-    Value IndexODim;
-    Value ODim;
-    Value DDim;
-    Value PadDim;
-    Value ODimDDim;
-    Value IDim0;
-    Value IDim;
-    Value IDim0KDim;
-    Value IDimPadDim;
-    Value IDim1;
-    Value IDim0Clamped;
-    Value IDim1Clamped;
-    Value IDim1_IDim0_Clamped;
-    IndexODim = b.create<linalg::IndexOp>(location,
-                                          /*value=*/DimSizeFromSumPoolType[i]);
-    ODim = castIndexToInt64(b, location, IndexODim);
-    DDim = b.create<arith::ConstantOp>(location,
-                                       b.getI64IntegerAttr(strideInts[i]));
-    PadDim = b.create<arith::ConstantOp>(location,
-                                         b.getI64IntegerAttr(paddingInts[i]));
-    ODimDDim = b.create<arith::MulIOp>(location, ODim, DDim);
-    IDim0 = b.create<arith::SubIOp>(location, ODimDDim, PadDim);
-    IDim = castIndexToInt64(b, location, InputSpatialDimValues[i]);
-    IDim0KDim =
+    Value IndexODim =
+        b.create<linalg::IndexOp>(location,
+                                  /*value=*/DimSizeFromSumPoolType[i]);
+    Value ODim = castIndexToInt64(b, location, IndexODim);
+    Value DDim = b.create<arith::ConstantOp>(
+        location, b.getI64IntegerAttr(strideInts[i]));
+    Value PadDim = b.create<arith::ConstantOp>(
+        location, b.getI64IntegerAttr(paddingInts[i]));
+    Value ODimDDim = b.create<arith::MulIOp>(location, ODim, DDim);
+    Value IDim0 = b.create<arith::SubIOp>(location, ODimDDim, PadDim);
+    Value IDim = castIndexToInt64(b, location, InputSpatialDimValues[i]);
+    Value IDim0KDim =
         b.create<arith::AddIOp>(location, IDim0, kernelSizeIntValues[i]);
-    IDimPadDim = b.create<arith::AddIOp>(location, IDim, PadDim);
-    IDim1 = b.create<arith::MinSIOp>(location, IDim0KDim, IDimPadDim);
+    Value IDimPadDim = b.create<arith::AddIOp>(location, IDim, PadDim);
+    Value IDim1 = b.create<arith::MinSIOp>(location, IDim0KDim, IDimPadDim);
 
-    IDim0Clamped = b.create<arith::MaxSIOp>(location, IDim0, cstZero);
-    IDim1Clamped = b.create<arith::MinSIOp>(location, IDim1, IDim);
-    IDim1_IDim0_Clamped =
+    Value IDim0Clamped = b.create<arith::MaxSIOp>(location, IDim0, cstZero);
+    Value IDim1Clamped = b.create<arith::MinSIOp>(location, IDim1, IDim);
+    Value IDim1_IDim0_Clamped =
         b.create<arith::SubIOp>(location, IDim1Clamped, IDim0Clamped);
     if (i == 0) {
       poolSize = IDim1_IDim0_Clamped;

From a56df76451ae3c4e1a2612cbe103ba9f3451b9b4 Mon Sep 17 00:00:00 2001
From: Ivan Garcia
Date: Wed, 19 Feb 2025 11:36:07 -0500
Subject: [PATCH 09/13] Addressing second round of Sayan's feedback.

---
 lib/Conversion/TorchToLinalg/Pooling.cpp | 46 +++++++++++++-----------
 projects/pt1/e2e_testing/xfail_sets.py   |  4 ---
 2 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index f3514961e285..51930f1410bd 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -905,7 +905,8 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(
     SmallVectorImpl<int64_t> &paddingInts) {
   Value poolSize;
 
-  Value cstZero = b.create<arith::ConstantOp>(location, b.getI64IntegerAttr(0));
+  Value cstZero =
+      b.createOrFold<arith::ConstantOp>(location, b.getI64IntegerAttr(0));
 
   for (int i = 0; i < NumOfDims; ++i) {
     // See the link below for the PyTorch implementation where
@@ -916,30 +917,32 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(
     // 2025 change, these variables used "height" and "width" (or
     // "h" and "w") in these intermediate variables instead of "Dim".
     Value IndexODim =
-        b.create<linalg::IndexOp>(location,
-                                  /*value=*/DimSizeFromSumPoolType[i]);
+        b.createOrFold<linalg::IndexOp>(location,
+                                        /*value=*/DimSizeFromSumPoolType[i]);
     Value ODim = castIndexToInt64(b, location, IndexODim);
-    Value DDim = b.create<arith::ConstantOp>(
+    Value DDim = b.createOrFold<arith::ConstantOp>(
         location, b.getI64IntegerAttr(strideInts[i]));
-    Value PadDim = b.create<arith::ConstantOp>(
+    Value PadDim = b.createOrFold<arith::ConstantOp>(
         location, b.getI64IntegerAttr(paddingInts[i]));
-    Value ODimDDim = b.create<arith::MulIOp>(location, ODim, DDim);
-    Value IDim0 = b.create<arith::SubIOp>(location, ODimDDim, PadDim);
+    Value ODimDDim = b.createOrFold<arith::MulIOp>(location, ODim, DDim);
+    Value IDim0 = b.createOrFold<arith::SubIOp>(location, ODimDDim, PadDim);
     Value IDim = castIndexToInt64(b, location, InputSpatialDimValues[i]);
     Value IDim0KDim =
-        b.create<arith::AddIOp>(location, IDim0, kernelSizeIntValues[i]);
-    Value IDimPadDim = b.create<arith::AddIOp>(location, IDim, PadDim);
-    Value IDim1 = b.create<arith::MinSIOp>(location, IDim0KDim, IDimPadDim);
-
-    Value IDim0Clamped = b.create<arith::MaxSIOp>(location, IDim0, cstZero);
-    Value IDim1Clamped = b.create<arith::MinSIOp>(location, IDim1, IDim);
+        b.createOrFold<arith::AddIOp>(location, IDim0, kernelSizeIntValues[i]);
+    Value IDimPadDim = b.createOrFold<arith::AddIOp>(location, IDim, PadDim);
+    Value IDim1 =
+        b.createOrFold<arith::MinSIOp>(location, IDim0KDim, IDimPadDim);
+
+    Value IDim0Clamped =
+        b.createOrFold<arith::MaxSIOp>(location, IDim0, cstZero);
+    Value IDim1Clamped = b.createOrFold<arith::MinSIOp>(location, IDim1, IDim);
     Value IDim1_IDim0_Clamped =
-        b.create<arith::SubIOp>(location, IDim1Clamped, IDim0Clamped);
+        b.createOrFold<arith::SubIOp>(location, IDim1Clamped, IDim0Clamped);
     if (i == 0) {
       poolSize = IDim1_IDim0_Clamped;
     } else {
-      poolSize =
-          b.create<arith::MulIOp>(location, poolSize, IDim1_IDim0_Clamped);
+      poolSize = b.createOrFold<arith::MulIOp>(location, poolSize,
+                                               IDim1_IDim0_Clamped);
     }
   }
   return poolSize;
@@ -993,10 +996,10 @@ static std::optional<LogicalResult> createAvgPoolValueCountIncludePadFalseCase(
                     convertScalarToDtype(b, loc, poolSize, resultElementType);
                 Value avg;
                 if (isa<mlir::FloatType>(resultElementType))
-                  avg = b.create<arith::DivFOp>(loc, args[0], divisor);
+                  avg = b.createOrFold<arith::DivFOp>(loc, args[0], divisor);
                 else if (isa<mlir::IntegerType>(resultElementType))
-                  avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
-                b.create<linalg::YieldOp>(loc, avg);
+                  avg = b.createOrFold<arith::DivSIOp>(loc, args[0], divisor);
+                b.createOrFold<linalg::YieldOp>(loc, avg);
               })
           .getResult(0);
   rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool);
@@ -1019,9 +1022,10 @@ static LogicalResult createAvgPoolValueCountIncludePadTrueCase(
 
   Value divisor = kernelSizeIntValues[0];
   for (uint32_t i = 1; i < kernelSizeIntValues.size(); ++i) {
-    divisor =
-        rewriter.create<arith::MulIOp>(loc, divisor, kernelSizeIntValues[i]);
+    divisor = rewriter.createOrFold<arith::MulIOp>(loc, divisor,
+                                                   kernelSizeIntValues[i]);
   }
+  // Only average pooling 2D/3D have optional divisor override.
   if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {
     divisor = isa<Torch::NoneType>(op.getDivisorOverride().getType())
                   ? divisor
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 5fe29994f032..90f1e916a86b 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1989,11 +1989,7 @@
     "MeshgridIndexingIJ_basic",
     "MeshgridIndexingXY_basic",
     "Meshgrid_basic",
-    "AvgPool1dCountIncludePadFalseWithoutPadding_basic",
-    "AvgPool1dCountIncludePadFalse_basic",
     "AvgPool2dCountIncludePadFalseStaticModule_basic",
-    "AvgPool3dCountIncludePadFalseWithoutPadding_basic",
-    "AvgPool3dCountIncludePadFalse_basic",
     "TensorSplitSections_GetItemModule_basic",
     "TensorSplitSections_ListUnpackModule_basic",
     "Atleast1dModule0dInput_basic",

From c7300d1e5ee83633fbe5ee5d499872d3da0a04e1 Mon Sep 17 00:00:00 2001
From: Ivan Garcia
Date: Wed, 19 Feb 2025 13:26:26 -0500
Subject: [PATCH 10/13] Extra space between statements.

---
 lib/Conversion/TorchToLinalg/Pooling.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index 51930f1410bd..6ce86d468894 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -1025,6 +1025,7 @@ static LogicalResult createAvgPoolValueCountIncludePadTrueCase(
     divisor = rewriter.createOrFold<arith::MulIOp>(loc, divisor,
                                                    kernelSizeIntValues[i]);
   }
+
   // Only average pooling 2D/3D have optional divisor override.
   if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {
     divisor = isa<Torch::NoneType>(op.getDivisorOverride().getType())

From d48b1582041b2ae6c1a5d170a69999e3c870c43b Mon Sep 17 00:00:00 2001
From: Ivan Garcia
Date: Wed, 19 Feb 2025 13:28:41 -0500
Subject: [PATCH 11/13] Reverting added space.

---
 lib/Conversion/TorchToLinalg/Pooling.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index 6ce86d468894..51930f1410bd 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -1025,7 +1025,6 @@ static LogicalResult createAvgPoolValueCountIncludePadTrueCase(
     divisor = rewriter.createOrFold<arith::MulIOp>(loc, divisor,
                                                    kernelSizeIntValues[i]);
   }
-
   // Only average pooling 2D/3D have optional divisor override.
   if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {
     divisor = isa<Torch::NoneType>(op.getDivisorOverride().getType())

From b76f3835c7ed123e138ece4d64d875d582ecc10a Mon Sep 17 00:00:00 2001
From: Ivan Garcia
Date: Wed, 19 Feb 2025 13:34:25 -0500
Subject: [PATCH 12/13] Adding space back.

---
 lib/Conversion/TorchToLinalg/Pooling.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index 51930f1410bd..6ce86d468894 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -1025,6 +1025,7 @@ static LogicalResult createAvgPoolValueCountIncludePadTrueCase(
     divisor = rewriter.createOrFold<arith::MulIOp>(loc, divisor,
                                                    kernelSizeIntValues[i]);
   }
+
  // Only average pooling 2D/3D have optional divisor override.
   if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {
     divisor = isa<Torch::NoneType>(op.getDivisorOverride().getType())

From 3b0eb6788bf3f2602e1d15f1b7e70705609d6420 Mon Sep 17 00:00:00 2001
From: Ivan Garcia
Date: Wed, 19 Feb 2025 14:11:14 -0500
Subject: [PATCH 13/13] Merging recent changes.

---
 externals/llvm-project | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/externals/llvm-project b/externals/llvm-project
index a854c266b984..5d6d982df61d 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit a854c266b98468ad4479a7d3c56a3fa76437e30d
+Subproject commit 5d6d982df61d16b6d498e6d59dd91c059679d3d8
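A note on the divisor math for reviewers: per spatial dimension, the arith
ops built in PoolSizeCalculator::getPoolSize mirror the scalar window
computation of the PyTorch CPU kernel (AvgPoolKernel.cpp) linked from the
code comments. A self-contained C++ sketch of that per-dimension logic
follows; the function name, variable names, and sample numbers are
illustrative only, not part of the patches:

  #include <algorithm>
  #include <cassert>
  #include <cstdint>

  // Number of input elements actually covered by one pooling window along
  // a single spatial dimension, i.e. the divisor contribution when
  // count_include_pad == false. oDim is the output index; kDim, dDim and
  // padDim are the kernel size, stride and padding; iDimSize is the input
  // extent along this dimension.
  int64_t poolSizeOneDim(int64_t oDim, int64_t kDim, int64_t dDim,
                         int64_t padDim, int64_t iDimSize) {
    int64_t iDim0 = oDim * dDim - padDim;                      // window start
    int64_t iDim1 = std::min(iDim0 + kDim, iDimSize + padDim); // window end
    iDim0 = std::max(iDim0, int64_t{0}); // clamp away the low-side padding
    iDim1 = std::min(iDim1, iDimSize);   // clamp away the high-side padding
    return iDim1 - iDim0;                // real (non-pad) elements in window
  }

  int main() {
    // 1-D example: input length 10, kernel 3, stride 1, padding 1.
    // Edge windows overlap the padding and cover only 2 real elements;
    // interior windows cover all 3.
    assert(poolSizeOneDim(0, 3, 1, 1, 10) == 2);
    assert(poolSizeOneDim(5, 3, 1, 1, 10) == 3);
    assert(poolSizeOneDim(9, 3, 1, 1, 10) == 2);
    return 0;
  }

For multiple spatial dimensions the divisor is the product of this quantity
over all dimensions, which is exactly the running arith.muli product that
getPoolSize builds (unless divisor_override is supplied, in which case that
value is used instead).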