Skip to content

Average pooling's clamped divisor should be applied under all conditions where the kernel can go out of bounds #4144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 105 additions & 58 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ namespace {
// used in the divisor of the average pooling operator.
template <int NumOfDims> class PoolSizeCalculator {
public:
PoolSizeCalculator(Value self, Value sumPool,
PoolSizeCalculator(Value self, Value sumPool, bool countIncludePad,
ConversionPatternRewriter &rewriter, Location loc);

// The algorithm for computing the divisor with
Expand All @@ -871,36 +871,37 @@ template <int NumOfDims> class PoolSizeCalculator {
SmallVectorImpl<int64_t> &paddingInts);

private:
int64_t DimSizeFromSumPoolType[NumOfDims];
Value InputSpatialDimValues[NumOfDims];
int64_t SumPoolTypeDimIndex[NumOfDims];
Value InputSpatialDimSizes[NumOfDims];
Location location;
bool isCountIncludePad;
};

} // namespace

template <int NumOfDims>
PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
Value self, Value sumPool, ConversionPatternRewriter &rewriter,
Location loc)
: location(loc) {
Value self, Value sumPool, bool countIncludePad,
ConversionPatternRewriter &rewriter, Location loc)
: location(loc), isCountIncludePad(countIncludePad) {
auto selfType = cast<RankedTensorType>(self.getType());
const int64_t selfRank = selfType.getRank();
RankedTensorType sumPoolType = cast<RankedTensorType>(sumPool.getType());
const int64_t rank = sumPoolType.getRank();

// Store dimensions in this order:
// 0 => width, 1 => height, 2 => depth
// 0 => depth, 1 => height, 2 => width
for (int i = 0; i < NumOfDims; ++i) {
int64_t DimSizeFromSelfType = toPositiveDim(-(i + 1), selfRank);
InputSpatialDimValues[i] =
getDimOp(rewriter, location, self, DimSizeFromSelfType);
DimSizeFromSumPoolType[i] = toPositiveDim(-(i + 1), rank);
int64_t inputSpatialDimIndex = toPositiveDim(-(i + 1), selfRank);
InputSpatialDimSizes[NumOfDims - i - 1] =
getDimOp(rewriter, location, self, inputSpatialDimIndex);
SumPoolTypeDimIndex[NumOfDims - i - 1] = toPositiveDim(-(i + 1), rank);
}
}

template <int NumOfDims>
Value PoolSizeCalculator<NumOfDims>::getPoolSize(
OpBuilder &b, SmallVectorImpl<Value> &kernelSizeIntValues,
OpBuilder &b, SmallVectorImpl<Value> &kernelDimSizes,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts) {
Value poolSize;
Expand All @@ -915,19 +916,20 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(
// Dim below stands for spatial dimension. Prior to the February 2025
// change, these variables used "height" and "width" (or "h" and "w")
// in these intermediate variables instead of "Dim".

Value IndexODim =
b.create<linalg::IndexOp>(location,
/*value=*/DimSizeFromSumPoolType[i]);
/*value=*/SumPoolTypeDimIndex[i]);
Value ODim = castIndexToInt64(b, location, IndexODim);
Value DDim = b.createOrFold<arith::ConstantOp>(
location, b.getI64IntegerAttr(strideInts[i]));
Value PadDim = b.createOrFold<arith::ConstantOp>(
location, b.getI64IntegerAttr(paddingInts[i]));
Value ODimDDim = b.createOrFold<arith::MulIOp>(location, ODim, DDim);
Value IDim0 = b.createOrFold<arith::SubIOp>(location, ODimDDim, PadDim);
Value IDim = castIndexToInt64(b, location, InputSpatialDimValues[i]);
Value IDim = castIndexToInt64(b, location, InputSpatialDimSizes[i]);
Value IDim0KDim =
b.createOrFold<arith::AddIOp>(location, IDim0, kernelSizeIntValues[i]);
b.createOrFold<arith::AddIOp>(location, IDim0, kernelDimSizes[i]);
Value IDimPadDim = b.createOrFold<arith::AddIOp>(location, IDim, PadDim);
Value IDim1 =
b.createOrFold<arith::MinSIOp>(location, IDim0KDim, IDimPadDim);
Expand All @@ -937,11 +939,15 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(
Value IDim1Clamped = b.createOrFold<arith::MinSIOp>(location, IDim1, IDim);
Value IDim1_IDim0_Clamped =
b.createOrFold<arith::SubIOp>(location, IDim1Clamped, IDim0Clamped);

Value poolSizeDim =
!isCountIncludePad
? IDim1_IDim0_Clamped
: b.createOrFold<arith::SubIOp>(location, IDim1, IDim0);
if (i == 0) {
poolSize = IDim1_IDim0_Clamped;
poolSize = poolSizeDim;
} else {
poolSize = b.createOrFold<arith::MulIOp>(location, poolSize,
IDim1_IDim0_Clamped);
poolSize = b.createOrFold<arith::MulIOp>(location, poolSize, poolSizeDim);
}
}
return poolSize;
Expand All @@ -957,26 +963,35 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override;

// Creates the average pooling operation value when the
// count_include_pad parameter is equal to false.
static std::optional<LogicalResult>
createAvgPoolValueCountIncludePadFalseCase(
bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelSizeIntValues,
// If the condition below is true, the divisor total must subtract the
// elements not counted (clamped divisor count). If false, the divisor
// is just the product of kernel dimensions.
static bool
doesAvgPoolDivisorNeedsClamping(bool ceilMode, bool countIncludePad,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts);

// Creates the average pooling operation value with a clamped
// divisor. The clamped divisor is the product of kernel
// dimensions minus the elements not counted; e.g., padding
// and ceiling mode implicit padding.
static LogicalResult createAveragePoolValueWithClampedDivisor(
bool ceilMode, bool countIncludePad, OpTy op,
typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter,
Value self, Value sumPool, Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelDimSizes,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts,
SmallVector<AffineMap> &indexingMapsAvg,
SmallVector<utils::IteratorType> &iteratorTypesAvg);

// Creates the average pooling operation value when the
// count_include_pad parameter is equal to true.
static LogicalResult createAvgPoolValueCountIncludePadTrueCase(
// Creates the average pooling operation value with a
// regular divisor; i.e., the product of kernel dimensions.
static LogicalResult createAveragePoolValueWithRegularDivisor(
OpTy op, typename OpTy::Adaptor &adaptor,
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelSizeIntValues,
SmallVectorImpl<Value> &kernelDimSizes,
SmallVector<AffineMap> &indexingMapsAvg,
SmallVector<utils::IteratorType> &iteratorTypesAvg);
};
Expand Down Expand Up @@ -1040,27 +1055,64 @@ LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::matchAndRewrite(
SmallVector<utils::IteratorType> iteratorTypesAvg(
Dim + 2, utils::IteratorType::parallel);

auto divisorOpResult = createAvgPoolValueCountIncludePadFalseCase(
countIncludePad, op, adaptor, rewriter, self, sumPool, outputTensor,
resultType, kernelSizeIntValues, strideInts, paddingInts, indexingMapsAvg,
iteratorTypesAvg);
if (divisorOpResult)
return *divisorOpResult;
if (doesAvgPoolDivisorNeedsClamping(ceilMode, countIncludePad, strideInts,
paddingInts)) {
return createAveragePoolValueWithClampedDivisor(
ceilMode, countIncludePad, op, adaptor, rewriter, self, sumPool,
outputTensor, resultType, kernelSizeIntValues, strideInts, paddingInts,
indexingMapsAvg, iteratorTypesAvg);
}

return createAvgPoolValueCountIncludePadTrueCase(
return createAveragePoolValueWithRegularDivisor(
op, adaptor, rewriter, self, sumPool, outputTensor, resultType,
kernelSizeIntValues, indexingMapsAvg, iteratorTypesAvg);
}

return success();
template <typename OpTy, typename PoolingOpTy, int Dim>
bool ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
doesAvgPoolDivisorNeedsClamping(bool ceilMode, bool countIncludePad,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts) {
// Determines whether the average pooling divisor needs to be clamped
// (i.e., adjusted to exclude padded or out-of-bounds elements).
//
// There are two primary cases where clamping is needed:
// 1. Padding with count_include_pad == false:
// - If padding is applied (padding != 0) and count_include_pad is false,
// then padding elements are *excluded* from the divisor, effectively
// clamping the divisor to the number of valid input elements.
//
// 2. Ceil mode with non-unit stride:
// - When ceil_mode is enabled, output dimensions are rounded up,
// potentially
// creating pooling windows that extend beyond the input tensor bounds.
// PyTorch handles this by implicitly adding zero-padding outside the
// tensor, but these extra (implicit) padded elements are *not* included
// in the divisor. This behavior is independent of the count_include_pad
// flag.
// - If all strides are 1, ceil_mode will not produce fractional divisions,
// so the windows will not extend beyond bounds, and no clamping occurs.
//
// Reference: PyTorch AvgPool2d documentation and formula for H_out/W_out:
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
//
// See torch.nn.AvgPool2d E2E tests for comprehensive coverage.

bool hasPadding =
!llvm::all_of(paddingInts, [](int64_t p) { return p == 0; });
bool allStridesUnitary =
llvm::all_of(strideInts, [](int64_t s) { return s == 1; });

return (!countIncludePad && hasPadding) || (ceilMode && !allStridesUnitary);
}

template <typename OpTy, typename PoolingOpTy, int Dim>
std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
createAvgPoolValueCountIncludePadFalseCase(
bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelSizeIntValues,
LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
createAveragePoolValueWithClampedDivisor(
bool ceilMode, bool countIncludePad, OpTy op,
typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter,
Value self, Value sumPool, Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelDimSizes,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts,
SmallVector<AffineMap> &indexingMapsAvg,
Expand All @@ -1069,11 +1121,6 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::

constexpr int avgPoolDims = getAvgPoolNumOfDims<OpTy>();

bool noPadding = llvm::all_of(paddingInts, [](int64_t p) { return p == 0; });
if (countIncludePad || noPadding) {
// These cases are not handled here.
return std::nullopt;
}
if (avgPoolDims < 1) {
return rewriter.notifyMatchFailure(
op, "Unexpected type. Only expected AtenAvgPool1dOp, AtenAvgPool2dOp, "
Expand All @@ -1082,8 +1129,8 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::

Type resultElementType = cast<RankedTensorType>(resultType).getElementType();

PoolSizeCalculator<avgPoolDims> poolSizeCalculator(self, sumPool, rewriter,
loc);
PoolSizeCalculator<avgPoolDims> poolSizeCalculator(
self, sumPool, countIncludePad, rewriter, loc);

// AtenAvgPool2/3dOp has an optional divisor_override
// attribute while AtenAvgPool1dOp does not.
Expand All @@ -1104,7 +1151,7 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
[&](OpBuilder &b, Location loc, ValueRange args) {
if (!poolSize) {
poolSize = poolSizeCalculator.getPoolSize(
b, kernelSizeIntValues, strideInts, paddingInts);
b, kernelDimSizes, strideInts, paddingInts);
}
Value divisor =
convertScalarToDtype(b, loc, poolSize, resultElementType);
Expand All @@ -1122,21 +1169,21 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::

template <typename OpTy, typename PoolingOpTy, int Dim>
LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
createAvgPoolValueCountIncludePadTrueCase(
createAveragePoolValueWithRegularDivisor(
OpTy op, typename OpTy::Adaptor &adaptor,
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
Value outputTensor, Type resultType,
SmallVectorImpl<Value> &kernelSizeIntValues,
SmallVectorImpl<Value> &kernelDimSizes,
SmallVector<AffineMap> &indexingMapsAvg,
SmallVector<utils::IteratorType> &iteratorTypesAvg) {
Location loc = op->getLoc();

Type resultElementType = cast<RankedTensorType>(resultType).getElementType();

Value divisor = kernelSizeIntValues[0];
for (uint32_t i = 1; i < kernelSizeIntValues.size(); ++i) {
divisor = rewriter.createOrFold<arith::MulIOp>(loc, divisor,
kernelSizeIntValues[i]);
Value divisor = kernelDimSizes[0];
for (uint32_t i = 1; i < kernelDimSizes.size(); ++i) {
divisor =
rewriter.createOrFold<arith::MulIOp>(loc, divisor, kernelDimSizes[i]);
}
// Only average pooling 2D/3D have optional divisor override.
if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {
Expand Down
19 changes: 19 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,12 @@
"Aten_EmbeddingBagExample_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"AvgPool2dCeilPadNonUnitaryStrides_basic",
"AvgPool2dCeilNoPadStridedIncludePadding_basic",
"AvgPool2dCeilPaddingStridedIncludePadding_basic",
"AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
"AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
"AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
"AvgPool2dDivisorOverrideModule_basic",
"BernoulliTensorModule_basic",
"BincountMinlengthModule_basic",
Expand Down Expand Up @@ -2791,6 +2797,10 @@
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
"AvgPool2dSingleIntTupleParamsModule_basic",
"AvgPool2dWithoutPadModule_basic",
"AvgPool1dNoPadCeilPadNotIncluded_basic",
"AvgPool1dPadCeilPadNotIncluded_basic",
"AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded_basic",
"AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
"BatchMlpLayerModule_basic",
"BincountMinlengthModule_basic",
"BincountModule_basic",
Expand Down Expand Up @@ -3533,6 +3543,13 @@
"AvgPool1dIntModule_basic",
"AvgPool1dStaticModule_basic",
"AvgPool2dCeilModeTrueModule_basic",
"AvgPool1dNoPadCeilPadNotIncluded_basic",
"AvgPool1dPadCeilPadNotIncluded_basic",
"AvgPool2dCeilPaddingStridedIncludePadding_basic",
"AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic",
"AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic",
"AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
"AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
"AvgPool2dDivisorOverrideModule_basic",
"AvgPool2dFloatModule_basic",
"AvgPool2dIntModule_basic",
Expand Down Expand Up @@ -3939,6 +3956,8 @@
"AtenKthvalueFloat64Module_basic",
"AtenKthvalueKeepDimModule_basic",
"AtenKthvalueModule_basic",
"AvgPool2dCeilNoPadUnitaryStrides_basic",
"AvgPool2dCeilPadNonUnitaryStrides_basic",
"AvgPool2dCountIncludePadFalseStaticModule_basic",
"AvgPool3dStaticModule_basic",
"Conv_Transpose1dModule_basic",
Expand Down
Loading
Loading