diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 45268452a992..9c4fb4fc8773 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -856,7 +856,7 @@ namespace { // used in the divisor of the average pooling operator. template class PoolSizeCalculator { public: - PoolSizeCalculator(Value self, Value sumPool, + PoolSizeCalculator(Value self, Value sumPool, bool countIncludePad, ConversionPatternRewriter &rewriter, Location loc); // The algorithm for computing the divisor with @@ -871,36 +871,37 @@ template class PoolSizeCalculator { SmallVectorImpl &paddingInts); private: - int64_t DimSizeFromSumPoolType[NumOfDims]; - Value InputSpatialDimValues[NumOfDims]; + int64_t SumPoolTypeDimIndex[NumOfDims]; + Value InputSpatialDimSizes[NumOfDims]; Location location; + bool isCountIncludePad; }; } // namespace template PoolSizeCalculator::PoolSizeCalculator( - Value self, Value sumPool, ConversionPatternRewriter &rewriter, - Location loc) - : location(loc) { + Value self, Value sumPool, bool countIncludePad, + ConversionPatternRewriter &rewriter, Location loc) + : location(loc), isCountIncludePad(countIncludePad) { auto selfType = cast(self.getType()); const int64_t selfRank = selfType.getRank(); RankedTensorType sumPoolType = cast(sumPool.getType()); const int64_t rank = sumPoolType.getRank(); // Store dimensions in this order: - // 0 => width, 1 => height, 2 => depth + // 0 => depth, 1 => height, 2 => width for (int i = 0; i < NumOfDims; ++i) { - int64_t DimSizeFromSelfType = toPositiveDim(-(i + 1), selfRank); - InputSpatialDimValues[i] = - getDimOp(rewriter, location, self, DimSizeFromSelfType); - DimSizeFromSumPoolType[i] = toPositiveDim(-(i + 1), rank); + int64_t inputSpatialDimIndex = toPositiveDim(-(i + 1), selfRank); + InputSpatialDimSizes[NumOfDims - i - 1] = + getDimOp(rewriter, location, self, inputSpatialDimIndex); + SumPoolTypeDimIndex[NumOfDims - i - 1] = 
toPositiveDim(-(i + 1), rank); } } template Value PoolSizeCalculator::getPoolSize( - OpBuilder &b, SmallVectorImpl &kernelSizeIntValues, + OpBuilder &b, SmallVectorImpl &kernelDimSizes, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts) { Value poolSize; @@ -915,9 +916,10 @@ Value PoolSizeCalculator::getPoolSize( // Dim below stands for spatial dimension. Prior to the February 2025 // change, these variables used "height" and "width" (or "h" and "w") // in these intermediate variables instead of "Dim". + Value IndexODim = b.create(location, - /*value=*/DimSizeFromSumPoolType[i]); + /*value=*/SumPoolTypeDimIndex[i]); Value ODim = castIndexToInt64(b, location, IndexODim); Value DDim = b.createOrFold( location, b.getI64IntegerAttr(strideInts[i])); @@ -925,9 +927,9 @@ Value PoolSizeCalculator::getPoolSize( location, b.getI64IntegerAttr(paddingInts[i])); Value ODimDDim = b.createOrFold(location, ODim, DDim); Value IDim0 = b.createOrFold(location, ODimDDim, PadDim); - Value IDim = castIndexToInt64(b, location, InputSpatialDimValues[i]); + Value IDim = castIndexToInt64(b, location, InputSpatialDimSizes[i]); Value IDim0KDim = - b.createOrFold(location, IDim0, kernelSizeIntValues[i]); + b.createOrFold(location, IDim0, kernelDimSizes[i]); Value IDimPadDim = b.createOrFold(location, IDim, PadDim); Value IDim1 = b.createOrFold(location, IDim0KDim, IDimPadDim); @@ -937,11 +939,15 @@ Value PoolSizeCalculator::getPoolSize( Value IDim1Clamped = b.createOrFold(location, IDim1, IDim); Value IDim1_IDim0_Clamped = b.createOrFold(location, IDim1Clamped, IDim0Clamped); + + Value poolSizeDim = + !isCountIncludePad + ? 
IDim1_IDim0_Clamped + : b.createOrFold(location, IDim1, IDim0); if (i == 0) { - poolSize = IDim1_IDim0_Clamped; + poolSize = poolSizeDim; } else { - poolSize = b.createOrFold(location, poolSize, - IDim1_IDim0_Clamped); + poolSize = b.createOrFold(location, poolSize, poolSizeDim); } } return poolSize; @@ -957,26 +963,35 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override; - // Creates the average pooling operation value when the - // count_include_pad parameter is equal to false. - static std::optional - createAvgPoolValueCountIncludePadFalseCase( - bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter, Value self, Value sumPool, - Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, + // If the condition below is true, the divisor total must subtract the + // elements not counted (clamped divisor count). If false, the divisor + // is just the product of kernel dimensions. + static bool + doesAvgPoolDivisorNeedsClamping(bool ceilMode, bool countIncludePad, + SmallVectorImpl &strideInts, + SmallVectorImpl &paddingInts); + + // Creates the average pooling operation value with a clamped + // divisor. The clamped divisor is the product of kernel + // dimensions minus the elements not counted; e.g., padding + // and ceiling mode implicit padding. + static LogicalResult createAveragePoolValueWithClampedDivisor( + bool ceilMode, bool countIncludePad, OpTy op, + typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter, + Value self, Value sumPool, Value outputTensor, Type resultType, + SmallVectorImpl &kernelDimSizes, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, SmallVector &indexingMapsAvg, SmallVector &iteratorTypesAvg); - // Creates the average pooling operation value when the - // count_include_pad parameter is equal to true. 
- static LogicalResult createAvgPoolValueCountIncludePadTrueCase( + // Creates the average pooling operation value with a + // regular divisor; i.e., the product of kernel dimensions. + static LogicalResult createAveragePoolValueWithRegularDivisor( OpTy op, typename OpTy::Adaptor &adaptor, ConversionPatternRewriter &rewriter, Value self, Value sumPool, Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &kernelDimSizes, SmallVector &indexingMapsAvg, SmallVector &iteratorTypesAvg); }; @@ -1040,27 +1055,64 @@ LogicalResult ConvertAtenAvgPoolOp::matchAndRewrite( SmallVector iteratorTypesAvg( Dim + 2, utils::IteratorType::parallel); - auto divisorOpResult = createAvgPoolValueCountIncludePadFalseCase( - countIncludePad, op, adaptor, rewriter, self, sumPool, outputTensor, - resultType, kernelSizeIntValues, strideInts, paddingInts, indexingMapsAvg, - iteratorTypesAvg); - if (divisorOpResult) - return *divisorOpResult; + if (doesAvgPoolDivisorNeedsClamping(ceilMode, countIncludePad, strideInts, + paddingInts)) { + return createAveragePoolValueWithClampedDivisor( + ceilMode, countIncludePad, op, adaptor, rewriter, self, sumPool, + outputTensor, resultType, kernelSizeIntValues, strideInts, paddingInts, + indexingMapsAvg, iteratorTypesAvg); + } - return createAvgPoolValueCountIncludePadTrueCase( + return createAveragePoolValueWithRegularDivisor( op, adaptor, rewriter, self, sumPool, outputTensor, resultType, kernelSizeIntValues, indexingMapsAvg, iteratorTypesAvg); +} - return success(); +template +bool ConvertAtenAvgPoolOp:: + doesAvgPoolDivisorNeedsClamping(bool ceilMode, bool countIncludePad, + SmallVectorImpl &strideInts, + SmallVectorImpl &paddingInts) { + // Determines whether the average pooling divisor needs to be clamped + // (i.e., adjusted to exclude padded or out-of-bounds elements). + // + // There are two primary cases where clamping is needed: + // 1. 
Padding with count_include_pad == false: + // - If padding is applied (padding != 0) and count_include_pad is false, + // then padding elements are *excluded* from the divisor, effectively + // clamping the divisor to the number of valid input elements. + // + // 2. Ceil mode with non-unit stride: + // - When ceil_mode is enabled, output dimensions are rounded up, + // potentially + // creating pooling windows that extend beyond the input tensor bounds. + // PyTorch handles this by implicitly adding zero-padding outside the + // tensor, but these extra (implicit) padded elements are *not* included + // in the divisor. This behavior is independent of the count_include_pad + // flag. + // - If all strides are 1, ceil_mode will not produce fractional divisions, + // so the windows will not extend beyond bounds, and no clamping occurs. + // + // Reference: PyTorch AvgPool2d documentation and formula for H_out/W_out: + // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html + // + // See torch.nn.AvgPool2d E2E tests for comprehensive coverage. 
+ + bool hasPadding = + !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; }); + bool allStridesUnitary = + llvm::all_of(strideInts, [](int64_t s) { return s == 1; }); + + return (!countIncludePad && hasPadding) || (ceilMode && !allStridesUnitary); } template -std::optional ConvertAtenAvgPoolOp:: - createAvgPoolValueCountIncludePadFalseCase( - bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter, Value self, Value sumPool, - Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, +LogicalResult ConvertAtenAvgPoolOp:: + createAveragePoolValueWithClampedDivisor( + bool ceilMode, bool countIncludePad, OpTy op, + typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter, + Value self, Value sumPool, Value outputTensor, Type resultType, + SmallVectorImpl &kernelDimSizes, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, SmallVector &indexingMapsAvg, @@ -1069,11 +1121,6 @@ std::optional ConvertAtenAvgPoolOp:: constexpr int avgPoolDims = getAvgPoolNumOfDims(); - bool noPadding = llvm::all_of(paddingInts, [](int64_t p) { return p == 0; }); - if (countIncludePad || noPadding) { - // These cases are not handled here. - return std::nullopt; - } if (avgPoolDims < 1) { return rewriter.notifyMatchFailure( op, "Unexpected type. Only expected AtenAvgPool1dOp, AtenAvgPool2dOp, " @@ -1082,8 +1129,8 @@ std::optional ConvertAtenAvgPoolOp:: Type resultElementType = cast(resultType).getElementType(); - PoolSizeCalculator poolSizeCalculator(self, sumPool, rewriter, - loc); + PoolSizeCalculator poolSizeCalculator( + self, sumPool, countIncludePad, rewriter, loc); // AtenAvgPool2/3dOp has an optional divisor_override // attribute while AtenAvgPool1dOp does not. 
@@ -1104,7 +1151,7 @@ std::optional ConvertAtenAvgPoolOp:: [&](OpBuilder &b, Location loc, ValueRange args) { if (!poolSize) { poolSize = poolSizeCalculator.getPoolSize( - b, kernelSizeIntValues, strideInts, paddingInts); + b, kernelDimSizes, strideInts, paddingInts); } Value divisor = convertScalarToDtype(b, loc, poolSize, resultElementType); @@ -1122,21 +1169,21 @@ std::optional ConvertAtenAvgPoolOp:: template LogicalResult ConvertAtenAvgPoolOp:: - createAvgPoolValueCountIncludePadTrueCase( + createAveragePoolValueWithRegularDivisor( OpTy op, typename OpTy::Adaptor &adaptor, ConversionPatternRewriter &rewriter, Value self, Value sumPool, Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &kernelDimSizes, SmallVector &indexingMapsAvg, SmallVector &iteratorTypesAvg) { Location loc = op->getLoc(); Type resultElementType = cast(resultType).getElementType(); - Value divisor = kernelSizeIntValues[0]; - for (uint32_t i = 1; i < kernelSizeIntValues.size(); ++i) { - divisor = rewriter.createOrFold(loc, divisor, - kernelSizeIntValues[i]); + Value divisor = kernelDimSizes[0]; + for (uint32_t i = 1; i < kernelDimSizes.size(); ++i) { + divisor = + rewriter.createOrFold(loc, divisor, kernelDimSizes[i]); } // Only average pooling 2D/3D have optional divisor override. 
if constexpr (!std::is_same()) { diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 42d7e01f9468..4e87574c6df8 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -650,6 +650,12 @@ "Aten_EmbeddingBagExample_basic", "Aten_TrilinearModuleVaryingRanks_basic", "Aten_TrilinearModuleZerodDimBug_basic", + "AvgPool2dCeilPadNonUnitaryStrides_basic", + "AvgPool2dCeilNoPadStridedIncludePadding_basic", + "AvgPool2dCeilPaddingStridedIncludePadding_basic", + "AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded_basic", + "AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic", + "AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic", "AvgPool2dDivisorOverrideModule_basic", "BernoulliTensorModule_basic", "BincountMinlengthModule_basic", @@ -2791,6 +2797,10 @@ "AvgPool2dSingleIntTupleParamsIncludePadModule_basic", "AvgPool2dSingleIntTupleParamsModule_basic", "AvgPool2dWithoutPadModule_basic", + "AvgPool1dNoPadCeilPadNotIncluded_basic", + "AvgPool1dPadCeilPadNotIncluded_basic", + "AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded_basic", + "AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic", "BatchMlpLayerModule_basic", "BincountMinlengthModule_basic", "BincountModule_basic", @@ -3533,6 +3543,13 @@ "AvgPool1dIntModule_basic", "AvgPool1dStaticModule_basic", "AvgPool2dCeilModeTrueModule_basic", + "AvgPool1dNoPadCeilPadNotIncluded_basic", + "AvgPool1dPadCeilPadNotIncluded_basic", + "AvgPool2dCeilPaddingStridedIncludePadding_basic", + "AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic", + "AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic", + "AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic", + "AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic", "AvgPool2dDivisorOverrideModule_basic", "AvgPool2dFloatModule_basic", "AvgPool2dIntModule_basic", @@ -3939,6 +3956,8 @@ "AtenKthvalueFloat64Module_basic", "AtenKthvalueKeepDimModule_basic", "AtenKthvalueModule_basic", + 
"AvgPool2dCeilNoPadUnitaryStrides_basic", + "AvgPool2dCeilPadNonUnitaryStrides_basic", "AvgPool2dCountIncludePadFalseStaticModule_basic", "AvgPool3dStaticModule_basic", "Conv_Transpose1dModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index 4a43b99033c1..9ef3cffb2193 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -2514,3 +2514,469 @@ def MaxUnpool3dModulePad0_basic(module, tu: TestUtils): output, indices = pool(input) module.forward(output, indices) + + +class AvgPool2dCeilNoPadUnitaryStrides(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[0, 0], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilNoPadUnitaryStrides()) +def AvgPool2dCeilNoPadUnitaryStrides_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilPadNonUnitaryStrides(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[2, 2], + padding=[1, 1], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilPadNonUnitaryStrides()) +def AvgPool2dCeilPadNonUnitaryStrides_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilNoPadStridedIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = 
torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[2, 2], + padding=[0, 0], + ceil_mode=True, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilNoPadStridedIncludePadding()) +def AvgPool2dCeilNoPadStridedIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilNoPadUnitaryStrideIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[0, 0], + ceil_mode=True, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dCeilNoPadUnitaryStrideIncludePadding() +) +def AvgPool2dCeilNoPadUnitaryStrideIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilPaddingUnitaryStrideIncludePaddingFalse(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[1, 1], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dCeilPaddingUnitaryStrideIncludePaddingFalse() +) +def AvgPool2dCeilPaddingUnitaryStrideIncludePaddingFalse_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dFloorNoPadUnitaryStrideIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + 
stride=[1, 1], + padding=[0, 0], + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dFloorNoPadUnitaryStrideIncludePadding() +) +def AvgPool2dFloorNoPadUnitaryStrideIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dFloorPaddingUnitaryStrideIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[1, 1], + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dFloorPaddingUnitaryStrideIncludePadding() +) +def AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilPaddingUnitaryStrideIncludePadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[1, 1], + ceil_mode=True, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dCeilPaddingUnitaryStrideIncludePadding() +) +def AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dCeilPaddingStridedIncludePadding(torch.nn.Module): + # Note that in this case the kernel window center will go into the padding. 
+ # When this happens the padding elements are counted in the divisor, but + # the out of bound elements from the ceiling are not counted + # (i.e., clamped from the divisor count). + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[2, 2], + padding=[1, 1], + ceil_mode=True, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCeilPaddingStridedIncludePadding()) +def AvgPool2dCeilPaddingStridedIncludePadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 4, low=-1)) + + +class AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded(torch.nn.Module): + # This test captures the torch-mlir issue reported here: + # https://github.com/llvm/torch-mlir/issues/4079 + # The issue was caused by having the ceil_mode = true and + # count_include_pad = false. Also the kernel and stride sizes are + # different in this test to make sure that they are processed in + # the right order. + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 2], + stride=[2, 3], + padding=[0, 0], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 3, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded() +) +def AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 4, low=-1)) + + +class AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded(torch.nn.Module): + # Different sizes used for each kernel, stride, and padding.dimensions. 
+ + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 4], + stride=[2, 3], + padding=[1, 2], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 3, 4], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded() +) +def AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 4, low=-1)) + + +class AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded(torch.nn.Module): + # 3D version of AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded. + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool3d( + kernel_size=[3, 2, 4], + stride=[3, 2, 5], + padding=[0, 0, 0], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 4, 5, 7], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded() +) +def AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 4, 5, 7, low=-1)) + + +class AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded(torch.nn.Module): + # 3-D version of AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded. 
+ + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool3d( + kernel_size=[3, 4, 7], + stride=[2, 3, 4], + padding=[1, 2, 3], + ceil_mode=True, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 3, 4, 7], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case( + module_factory=lambda: AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded() +) +def AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 4, 7, low=-1)) + + +class AvgPool1dNoPadCeilPadNotIncluded(torch.nn.Module): + # 1D version of AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded. + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool1d( + kernel_size=[2], + stride=[2], + padding=[1], + ceil_mode=True, + count_include_pad=False, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 5], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool1dNoPadCeilPadNotIncluded()) +def AvgPool1dNoPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 5, low=-1)) + + +class AvgPool1dPadCeilPadNotIncluded(torch.nn.Module): + # 1-D version of AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded. 
+ + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool1d( + kernel_size=[2], + stride=[2], + padding=[1], + ceil_mode=True, + count_include_pad=False, + ) + + @export + @annotate_args( + [ + None, + ([1, 1, 3], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool1dPadCeilPadNotIncluded()) +def AvgPool1dPadCeilPadNotIncluded_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, low=-1)) diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir index c065e624efa9..91043b83728a 100644 --- a/test/Conversion/TorchToLinalg/pooling.mlir +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -126,7 +126,7 @@ func.func @forward_avg_pool2d_countincludepad_false(%arg0: !torch.vtensor<[1,3,6 // CHECK: linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<[1, 2]> : vector<2xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x3x64x58xf32>, tensor<4x5xf32>) outs(%[[OUT1:.*]] : tensor<1x3x61x27xf32>) -> tensor<1x3x61x27xf32> // CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x61x27xf32>) outs(%[[OUT2:.*]] : tensor<1x3x61x27xf32>) // CHECK-NEXT: ^bb0(%[[BIIN1:.*]]: f32, %[[BOUT1:.*]]: f32): - // CHECK-COUNT-4: arith.minsi + // CHECK-COUNT-1: arith.minsi // CHECK-COUNT-1: arith.divf // CHECK: linalg.yield %[[TMP1:.*]] : f32 // CHECK-NEXT: } -> tensor<1x3x61x27xf32> @@ -179,7 +179,7 @@ func.func @forward_avg_pool3dd_countincludepad_false(%arg0: !torch.vtensor<[1,3, // CHECK: linalg.pooling_ndhwc_sum {dilations = dense<1> : vector<3xi64>, strides = dense<[1, 2, 1]> : vector<3xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x7x66x58x3xf32>, tensor<4x5x5xf32>) outs(%[[OUT1:.*]] : tensor<1x4x31x54x3xf32>) -> tensor<1x4x31x54x3xf32> // CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", 
"parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x4x31x54xf32>) outs(%[[OUT2:.*]] : tensor<1x3x4x31x54xf32>) // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32): - // CHECK-COUNT-6: arith.minsi + // CHECK-COUNT-3: arith.minsi // CHECK-COUNT-1: arith.divf // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32 // CHECK-NEXT: } -> tensor<1x3x4x31x54xf32> @@ -221,7 +221,7 @@ func.func @forward_avg_pool1d_countincludepad_false(%arg0: !torch.vtensor<[1,512 // CHECK: linalg.pooling_ncw_sum {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} ins(%[[IN1:.*]], %[[IN2:.*]] : tensor<1x512x12xf32>, tensor<1xf32>) outs(%[[OUT1:.*]] : tensor<1x512x12xf32>) -> tensor<1x512x12xf32> // CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[IN3:.*]] : tensor<1x512x12xf32>) outs(%[[OUT2:.*]] : tensor<1x512x12xf32> // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32): - // CHECK-COUNT-2: arith.minsi + // CHECK-COUNT-1: arith.minsi // CHECK-COUNT-1: arith.divf // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32 // CHECK-NEXT: } -> tensor<1x512x12xf32> @@ -233,3 +233,31 @@ func.func @forward_avg_pool1d_countincludepad_false(%arg0: !torch.vtensor<[1,512 %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %false : !torch.vtensor<[1,512,10],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,12],f32> return %3 : !torch.vtensor<[1,512,12],f32> } + +// CHECK-LABEL: func @forward_avgpool_2d_ceil +func.func @forward_avgpool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> { + // CHECK: %[[POOL_OUT:.*]] = linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins(%[[PADDED_IN:.*]], %[[KERNEL_IN:.*]] : tensor<1x1x6x6xf32>, tensor<3x3xf32>) outs(%[[OUT1:.*]] : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32> + // CHECK: linalg.generic {indexing_maps = [#map1, #map1], 
iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[POOL_OUT]] : tensor<1x1x2x2xf32>) outs(%[[GEN_OUT:.*]] : tensor<1x1x2x2xf32>) {
+  // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
+  // CHECK-COUNT-3: arith.muli
+  // CHECK-COUNT-1: arith.sitofp
+  // CHECK-COUNT-1: arith.divf
+  // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32
+  // CHECK-NEXT: } -> tensor<1x1x2x2xf32>
+  %int3 = torch.constant.int 3
+  %int3_0 = torch.constant.int 3
+  %int0 = torch.constant.int 0
+  %int0_1 = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %int2_2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int1_3 = torch.constant.int 1
+  %0 = torch.prim.ListConstruct %int3, %int3_0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int0, %int0_1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int2, %int2_2, %int1, %int1_3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %true = torch.constant.bool true
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+  %3 = torch.aten.avg_pool2d %arg0, %0, %2, %1, %true, %false, %none : !torch.vtensor<[1,1,4,4],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,2,2],f32>
+  return %3 : !torch.vtensor<[1,1,2,2],f32>
+}