Skip to content

Commit 4b206d7

Browse files
ivangarcia44 (Ivan Garcia)
and
Ivan Garcia
authored
Average pooling clamped divisor should be done on all conditions where the kernel can go out of bounds (#4144)
In this pull request I added various E2E tests to cover edge cases in the average pooling torch-to-linalg lowering algorithm. These new tests uncovered various numerical issues that were addressed in this same PR. One of the issues is the IREE test failure found in #4079, which triggered this work. Background: In the most common case, the divisor for average pooling is just the product of the kernel dimensions. But with the padding and ceil mode options, some elements need to be discounted from the divisor computation. This change fixes two components of this: 1) Fix the condition that determines whether the divisor is just the product of kernel dimensions or requires the clamped divisor computation. 2) Add the missing isCountIncludePad logic to the divisor computation algorithm, and reverse the element order of the kernel/stride/padding parameters. Both were missing from the first generalization change. --------- Co-authored-by: Ivan Garcia <[email protected]> Co-authored-by: ivangarcia44 <ivangarcia44>
1 parent 4e2d0fd commit 4b206d7

File tree

4 files changed

+621
-61
lines changed

4 files changed

+621
-61
lines changed

lib/Conversion/TorchToLinalg/Pooling.cpp

Lines changed: 105 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,7 @@ namespace {
862862
// used in the divisor of the average pooling operator.
863863
template <int NumOfDims> class PoolSizeCalculator {
864864
public:
865-
PoolSizeCalculator(Value self, Value sumPool,
865+
PoolSizeCalculator(Value self, Value sumPool, bool countIncludePad,
866866
ConversionPatternRewriter &rewriter, Location loc);
867867

868868
// The algorithm for computing the divisor with
@@ -877,36 +877,37 @@ template <int NumOfDims> class PoolSizeCalculator {
877877
SmallVectorImpl<int64_t> &paddingInts);
878878

879879
private:
880-
int64_t DimSizeFromSumPoolType[NumOfDims];
881-
Value InputSpatialDimValues[NumOfDims];
880+
int64_t SumPoolTypeDimIndex[NumOfDims];
881+
Value InputSpatialDimSizes[NumOfDims];
882882
Location location;
883+
bool isCountIncludePad;
883884
};
884885

885886
} // namespace
886887

887888
template <int NumOfDims>
888889
PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
889-
Value self, Value sumPool, ConversionPatternRewriter &rewriter,
890-
Location loc)
891-
: location(loc) {
890+
Value self, Value sumPool, bool countIncludePad,
891+
ConversionPatternRewriter &rewriter, Location loc)
892+
: location(loc), isCountIncludePad(countIncludePad) {
892893
auto selfType = cast<RankedTensorType>(self.getType());
893894
const int64_t selfRank = selfType.getRank();
894895
RankedTensorType sumPoolType = cast<RankedTensorType>(sumPool.getType());
895896
const int64_t rank = sumPoolType.getRank();
896897

897898
// Store dimensions in this order:
898-
// 0 => width, 1 => height, 2 => depth
899+
// 0 => depth, 1 => height, 2 => width
899900
for (int i = 0; i < NumOfDims; ++i) {
900-
int64_t DimSizeFromSelfType = toPositiveDim(-(i + 1), selfRank);
901-
InputSpatialDimValues[i] =
902-
getDimOp(rewriter, location, self, DimSizeFromSelfType);
903-
DimSizeFromSumPoolType[i] = toPositiveDim(-(i + 1), rank);
901+
int64_t inputSpatialDimIndex = toPositiveDim(-(i + 1), selfRank);
902+
InputSpatialDimSizes[NumOfDims - i - 1] =
903+
getDimOp(rewriter, location, self, inputSpatialDimIndex);
904+
SumPoolTypeDimIndex[NumOfDims - i - 1] = toPositiveDim(-(i + 1), rank);
904905
}
905906
}
906907

907908
template <int NumOfDims>
908909
Value PoolSizeCalculator<NumOfDims>::getPoolSize(
909-
OpBuilder &b, SmallVectorImpl<Value> &kernelSizeIntValues,
910+
OpBuilder &b, SmallVectorImpl<Value> &kernelDimSizes,
910911
SmallVectorImpl<int64_t> &strideInts,
911912
SmallVectorImpl<int64_t> &paddingInts) {
912913
Value poolSize;
@@ -921,19 +922,20 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(
921922
// Dim below stands for spatial dimension. Prior to the February 2025
922923
// change, these variables used "height" and "width" (or "h" and "w")
923924
// in these intermediate variables instead of "Dim".
925+
924926
Value IndexODim =
925927
b.create<linalg::IndexOp>(location,
926-
/*value=*/DimSizeFromSumPoolType[i]);
928+
/*value=*/SumPoolTypeDimIndex[i]);
927929
Value ODim = castIndexToInt64(b, location, IndexODim);
928930
Value DDim = b.createOrFold<arith::ConstantOp>(
929931
location, b.getI64IntegerAttr(strideInts[i]));
930932
Value PadDim = b.createOrFold<arith::ConstantOp>(
931933
location, b.getI64IntegerAttr(paddingInts[i]));
932934
Value ODimDDim = b.createOrFold<arith::MulIOp>(location, ODim, DDim);
933935
Value IDim0 = b.createOrFold<arith::SubIOp>(location, ODimDDim, PadDim);
934-
Value IDim = castIndexToInt64(b, location, InputSpatialDimValues[i]);
936+
Value IDim = castIndexToInt64(b, location, InputSpatialDimSizes[i]);
935937
Value IDim0KDim =
936-
b.createOrFold<arith::AddIOp>(location, IDim0, kernelSizeIntValues[i]);
938+
b.createOrFold<arith::AddIOp>(location, IDim0, kernelDimSizes[i]);
937939
Value IDimPadDim = b.createOrFold<arith::AddIOp>(location, IDim, PadDim);
938940
Value IDim1 =
939941
b.createOrFold<arith::MinSIOp>(location, IDim0KDim, IDimPadDim);
@@ -943,11 +945,15 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(
943945
Value IDim1Clamped = b.createOrFold<arith::MinSIOp>(location, IDim1, IDim);
944946
Value IDim1_IDim0_Clamped =
945947
b.createOrFold<arith::SubIOp>(location, IDim1Clamped, IDim0Clamped);
948+
949+
Value poolSizeDim =
950+
!isCountIncludePad
951+
? IDim1_IDim0_Clamped
952+
: b.createOrFold<arith::SubIOp>(location, IDim1, IDim0);
946953
if (i == 0) {
947-
poolSize = IDim1_IDim0_Clamped;
954+
poolSize = poolSizeDim;
948955
} else {
949-
poolSize = b.createOrFold<arith::MulIOp>(location, poolSize,
950-
IDim1_IDim0_Clamped);
956+
poolSize = b.createOrFold<arith::MulIOp>(location, poolSize, poolSizeDim);
951957
}
952958
}
953959
return poolSize;
@@ -963,26 +969,35 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
963969
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
964970
ConversionPatternRewriter &rewriter) const override;
965971

966-
// Creates the average pooling operation value when the
967-
// count_include_pad parameter is equal to false.
968-
static std::optional<LogicalResult>
969-
createAvgPoolValueCountIncludePadFalseCase(
970-
bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor,
971-
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
972-
Value outputTensor, Type resultType,
973-
SmallVectorImpl<Value> &kernelSizeIntValues,
972+
// If the condition below is true, the divisor total must subtract the
973+
// elements not counted (clamped divisor count). If false, the divisor
974+
// is just the product of kernel dimensions.
975+
static bool
976+
doesAvgPoolDivisorNeedsClamping(bool ceilMode, bool countIncludePad,
977+
SmallVectorImpl<int64_t> &strideInts,
978+
SmallVectorImpl<int64_t> &paddingInts);
979+
980+
// Creates the average pooling operation value with a clamped
981+
// divisor. The clamped divisor is the product of kernel
982+
// dimensions minus the elements not counted; e.g., padding
983+
// and ceiling mode implicit padding.
984+
static LogicalResult createAveragePoolValueWithClampedDivisor(
985+
bool ceilMode, bool countIncludePad, OpTy op,
986+
typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter,
987+
Value self, Value sumPool, Value outputTensor, Type resultType,
988+
SmallVectorImpl<Value> &kernelDimSizes,
974989
SmallVectorImpl<int64_t> &strideInts,
975990
SmallVectorImpl<int64_t> &paddingInts,
976991
SmallVector<AffineMap> &indexingMapsAvg,
977992
SmallVector<utils::IteratorType> &iteratorTypesAvg);
978993

979-
// Creates the average pooling operation value when the
980-
// count_include_pad parameter is equal to true.
981-
static LogicalResult createAvgPoolValueCountIncludePadTrueCase(
994+
// Creates the average pooling operation value with a
995+
// regular divisor; i.e., the product of kernel dimensions.
996+
static LogicalResult createAveragePoolValueWithRegularDivisor(
982997
OpTy op, typename OpTy::Adaptor &adaptor,
983998
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
984999
Value outputTensor, Type resultType,
985-
SmallVectorImpl<Value> &kernelSizeIntValues,
1000+
SmallVectorImpl<Value> &kernelDimSizes,
9861001
SmallVector<AffineMap> &indexingMapsAvg,
9871002
SmallVector<utils::IteratorType> &iteratorTypesAvg);
9881003
};
@@ -1046,27 +1061,64 @@ LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::matchAndRewrite(
10461061
SmallVector<utils::IteratorType> iteratorTypesAvg(
10471062
Dim + 2, utils::IteratorType::parallel);
10481063

1049-
auto divisorOpResult = createAvgPoolValueCountIncludePadFalseCase(
1050-
countIncludePad, op, adaptor, rewriter, self, sumPool, outputTensor,
1051-
resultType, kernelSizeIntValues, strideInts, paddingInts, indexingMapsAvg,
1052-
iteratorTypesAvg);
1053-
if (divisorOpResult)
1054-
return *divisorOpResult;
1064+
if (doesAvgPoolDivisorNeedsClamping(ceilMode, countIncludePad, strideInts,
1065+
paddingInts)) {
1066+
return createAveragePoolValueWithClampedDivisor(
1067+
ceilMode, countIncludePad, op, adaptor, rewriter, self, sumPool,
1068+
outputTensor, resultType, kernelSizeIntValues, strideInts, paddingInts,
1069+
indexingMapsAvg, iteratorTypesAvg);
1070+
}
10551071

1056-
return createAvgPoolValueCountIncludePadTrueCase(
1072+
return createAveragePoolValueWithRegularDivisor(
10571073
op, adaptor, rewriter, self, sumPool, outputTensor, resultType,
10581074
kernelSizeIntValues, indexingMapsAvg, iteratorTypesAvg);
1075+
}
10591076

1060-
return success();
1077+
template <typename OpTy, typename PoolingOpTy, int Dim>
1078+
bool ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
1079+
doesAvgPoolDivisorNeedsClamping(bool ceilMode, bool countIncludePad,
1080+
SmallVectorImpl<int64_t> &strideInts,
1081+
SmallVectorImpl<int64_t> &paddingInts) {
1082+
// Determines whether the average pooling divisor needs to be clamped
1083+
// (i.e., adjusted to exclude padded or out-of-bounds elements).
1084+
//
1085+
// There are two primary cases where clamping is needed:
1086+
// 1. Padding with count_include_pad == false:
1087+
// - If padding is applied (padding != 0) and count_include_pad is false,
1088+
// then padding elements are *excluded* from the divisor, effectively
1089+
// clamping the divisor to the number of valid input elements.
1090+
//
1091+
// 2. Ceil mode with non-unit stride:
1092+
// - When ceil_mode is enabled, output dimensions are rounded up,
1093+
// potentially
1094+
// creating pooling windows that extend beyond the input tensor bounds.
1095+
// PyTorch handles this by implicitly adding zero-padding outside the
1096+
// tensor, but these extra (implicit) padded elements are *not* included
1097+
// in the divisor. This behavior is independent of the count_include_pad
1098+
// flag.
1099+
// - If all strides are 1, ceil_mode will not produce fractional divisions,
1100+
// so the windows will not extend beyond bounds, and no clamping occurs.
1101+
//
1102+
// Reference: PyTorch AvgPool2d documentation and formula for H_out/W_out:
1103+
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
1104+
//
1105+
// See torch.nn.AvgPool2d E2E tests for comprehensive coverage.
1106+
1107+
bool hasPadding =
1108+
!llvm::all_of(paddingInts, [](int64_t p) { return p == 0; });
1109+
bool allStridesUnitary =
1110+
llvm::all_of(strideInts, [](int64_t s) { return s == 1; });
1111+
1112+
return (!countIncludePad && hasPadding) || (ceilMode && !allStridesUnitary);
10611113
}
10621114

10631115
template <typename OpTy, typename PoolingOpTy, int Dim>
1064-
std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
1065-
createAvgPoolValueCountIncludePadFalseCase(
1066-
bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor,
1067-
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
1068-
Value outputTensor, Type resultType,
1069-
SmallVectorImpl<Value> &kernelSizeIntValues,
1116+
LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
1117+
createAveragePoolValueWithClampedDivisor(
1118+
bool ceilMode, bool countIncludePad, OpTy op,
1119+
typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter,
1120+
Value self, Value sumPool, Value outputTensor, Type resultType,
1121+
SmallVectorImpl<Value> &kernelDimSizes,
10701122
SmallVectorImpl<int64_t> &strideInts,
10711123
SmallVectorImpl<int64_t> &paddingInts,
10721124
SmallVector<AffineMap> &indexingMapsAvg,
@@ -1075,11 +1127,6 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
10751127

10761128
constexpr int avgPoolDims = getAvgPoolNumOfDims<OpTy>();
10771129

1078-
bool noPadding = llvm::all_of(paddingInts, [](int64_t p) { return p == 0; });
1079-
if (countIncludePad || noPadding) {
1080-
// These cases are not handled here.
1081-
return std::nullopt;
1082-
}
10831130
if (avgPoolDims < 1) {
10841131
return rewriter.notifyMatchFailure(
10851132
op, "Unexpected type. Only expected AtenAvgPool1dOp, AtenAvgPool2dOp, "
@@ -1088,8 +1135,8 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
10881135

10891136
Type resultElementType = cast<RankedTensorType>(resultType).getElementType();
10901137

1091-
PoolSizeCalculator<avgPoolDims> poolSizeCalculator(self, sumPool, rewriter,
1092-
loc);
1138+
PoolSizeCalculator<avgPoolDims> poolSizeCalculator(
1139+
self, sumPool, countIncludePad, rewriter, loc);
10931140

10941141
// AtenAvgPool2/3dOp has an optional divisor_override
10951142
// attribute while AtenAvgPool1dOp does not.
@@ -1110,7 +1157,7 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
11101157
[&](OpBuilder &b, Location loc, ValueRange args) {
11111158
if (!poolSize) {
11121159
poolSize = poolSizeCalculator.getPoolSize(
1113-
b, kernelSizeIntValues, strideInts, paddingInts);
1160+
b, kernelDimSizes, strideInts, paddingInts);
11141161
}
11151162
Value divisor =
11161163
convertScalarToDtype(b, loc, poolSize, resultElementType);
@@ -1128,21 +1175,21 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
11281175

11291176
template <typename OpTy, typename PoolingOpTy, int Dim>
11301177
LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
1131-
createAvgPoolValueCountIncludePadTrueCase(
1178+
createAveragePoolValueWithRegularDivisor(
11321179
OpTy op, typename OpTy::Adaptor &adaptor,
11331180
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
11341181
Value outputTensor, Type resultType,
1135-
SmallVectorImpl<Value> &kernelSizeIntValues,
1182+
SmallVectorImpl<Value> &kernelDimSizes,
11361183
SmallVector<AffineMap> &indexingMapsAvg,
11371184
SmallVector<utils::IteratorType> &iteratorTypesAvg) {
11381185
Location loc = op->getLoc();
11391186

11401187
Type resultElementType = cast<RankedTensorType>(resultType).getElementType();
11411188

1142-
Value divisor = kernelSizeIntValues[0];
1143-
for (uint32_t i = 1; i < kernelSizeIntValues.size(); ++i) {
1144-
divisor = rewriter.createOrFold<arith::MulIOp>(loc, divisor,
1145-
kernelSizeIntValues[i]);
1189+
Value divisor = kernelDimSizes[0];
1190+
for (uint32_t i = 1; i < kernelDimSizes.size(); ++i) {
1191+
divisor =
1192+
rewriter.createOrFold<arith::MulIOp>(loc, divisor, kernelDimSizes[i]);
11461193
}
11471194
// Only average pooling 2D/3D have optional divisor override.
11481195
if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,12 @@
650650
"Aten_EmbeddingBagExample_basic",
651651
"Aten_TrilinearModuleVaryingRanks_basic",
652652
"Aten_TrilinearModuleZerodDimBug_basic",
653+
"AvgPool2dCeilPadNonUnitaryStrides_basic",
654+
"AvgPool2dCeilNoPadStridedIncludePadding_basic",
655+
"AvgPool2dCeilPaddingStridedIncludePadding_basic",
656+
"AvgPool2dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
657+
"AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
658+
"AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
653659
"AvgPool2dDivisorOverrideModule_basic",
654660
"BernoulliTensorModule_basic",
655661
"BincountMinlengthModule_basic",
@@ -2800,6 +2806,10 @@
28002806
"AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
28012807
"AvgPool2dSingleIntTupleParamsModule_basic",
28022808
"AvgPool2dWithoutPadModule_basic",
2809+
"AvgPool1dNoPadCeilPadNotIncluded_basic",
2810+
"AvgPool1dPadCeilPadNotIncluded_basic",
2811+
"AvgPool2dDiffKernelsStridesPadCeilPadNotIncluded_basic",
2812+
"AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
28032813
"BatchMlpLayerModule_basic",
28042814
"BincountMinlengthModule_basic",
28052815
"BincountModule_basic",
@@ -3549,6 +3559,13 @@
35493559
"AvgPool1dIntModule_basic",
35503560
"AvgPool1dStaticModule_basic",
35513561
"AvgPool2dCeilModeTrueModule_basic",
3562+
"AvgPool1dNoPadCeilPadNotIncluded_basic",
3563+
"AvgPool1dPadCeilPadNotIncluded_basic",
3564+
"AvgPool2dCeilPaddingStridedIncludePadding_basic",
3565+
"AvgPool2dCeilPaddingUnitaryStrideIncludePadding_basic",
3566+
"AvgPool2dFloorPaddingUnitaryStrideIncludePadding_basic",
3567+
"AvgPool3dDiffKernelsStridesNoPadCeilPadNotIncluded_basic",
3568+
"AvgPool3dDiffKernelsStridesPadCeilPadNotIncluded_basic",
35523569
"AvgPool2dDivisorOverrideModule_basic",
35533570
"AvgPool2dFloatModule_basic",
35543571
"AvgPool2dIntModule_basic",
@@ -3955,6 +3972,8 @@
39553972
"AtenKthvalueFloat64Module_basic",
39563973
"AtenKthvalueKeepDimModule_basic",
39573974
"AtenKthvalueModule_basic",
3975+
"AvgPool2dCeilNoPadUnitaryStrides_basic",
3976+
"AvgPool2dCeilPadNonUnitaryStrides_basic",
39583977
"AvgPool2dCountIncludePadFalseStaticModule_basic",
39593978
"AvgPool3dStaticModule_basic",
39603979
"Conv_Transpose1dModule_basic",

0 commit comments

Comments
 (0)