@@ -862,7 +862,7 @@ namespace {
862
862
// used in the divisor of the average pooling operator.
863
863
template <int NumOfDims> class PoolSizeCalculator {
864
864
public:
865
- PoolSizeCalculator (Value self, Value sumPool,
865
+ PoolSizeCalculator (Value self, Value sumPool, bool countIncludePad,
866
866
ConversionPatternRewriter &rewriter, Location loc);
867
867
868
868
// The algorithm for computing the divisor with
@@ -877,36 +877,37 @@ template <int NumOfDims> class PoolSizeCalculator {
877
877
SmallVectorImpl<int64_t > &paddingInts);
878
878
879
879
private:
880
- int64_t DimSizeFromSumPoolType [NumOfDims];
881
- Value InputSpatialDimValues [NumOfDims];
880
+ int64_t SumPoolTypeDimIndex [NumOfDims];
881
+ Value InputSpatialDimSizes [NumOfDims];
882
882
Location location;
883
+ bool isCountIncludePad;
883
884
};
884
885
885
886
} // namespace
886
887
887
888
template <int NumOfDims>
888
889
PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
889
- Value self, Value sumPool, ConversionPatternRewriter &rewriter ,
890
- Location loc)
891
- : location(loc) {
890
+ Value self, Value sumPool, bool countIncludePad ,
891
+ ConversionPatternRewriter &rewriter, Location loc)
892
+ : location(loc), isCountIncludePad(countIncludePad) {
892
893
auto selfType = cast<RankedTensorType>(self.getType ());
893
894
const int64_t selfRank = selfType.getRank ();
894
895
RankedTensorType sumPoolType = cast<RankedTensorType>(sumPool.getType ());
895
896
const int64_t rank = sumPoolType.getRank ();
896
897
897
898
// Store dimensions in this order:
898
- // 0 => width , 1 => height, 2 => depth
899
+ // 0 => depth , 1 => height, 2 => width
899
900
for (int i = 0 ; i < NumOfDims; ++i) {
900
- int64_t DimSizeFromSelfType = toPositiveDim (-(i + 1 ), selfRank);
901
- InputSpatialDimValues[i ] =
902
- getDimOp (rewriter, location, self, DimSizeFromSelfType );
903
- DimSizeFromSumPoolType[i ] = toPositiveDim (-(i + 1 ), rank);
901
+ int64_t inputSpatialDimIndex = toPositiveDim (-(i + 1 ), selfRank);
902
+ InputSpatialDimSizes[NumOfDims - i - 1 ] =
903
+ getDimOp (rewriter, location, self, inputSpatialDimIndex );
904
+ SumPoolTypeDimIndex[NumOfDims - i - 1 ] = toPositiveDim (-(i + 1 ), rank);
904
905
}
905
906
}
906
907
907
908
template <int NumOfDims>
908
909
Value PoolSizeCalculator<NumOfDims>::getPoolSize(
909
- OpBuilder &b, SmallVectorImpl<Value> &kernelSizeIntValues ,
910
+ OpBuilder &b, SmallVectorImpl<Value> &kernelDimSizes ,
910
911
SmallVectorImpl<int64_t > &strideInts,
911
912
SmallVectorImpl<int64_t > &paddingInts) {
912
913
Value poolSize;
@@ -921,19 +922,20 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(
921
922
// Dim below stands for spatial dimension. Prior to the February 2025
922
923
// change, these variables used "height" and "width" (or "h" and "w")
923
924
// in these intermediate variables instead of "Dim".
925
+
924
926
Value IndexODim =
925
927
b.create <linalg::IndexOp>(location,
926
- /* value=*/ DimSizeFromSumPoolType [i]);
928
+ /* value=*/ SumPoolTypeDimIndex [i]);
927
929
Value ODim = castIndexToInt64 (b, location, IndexODim);
928
930
Value DDim = b.createOrFold <arith::ConstantOp>(
929
931
location, b.getI64IntegerAttr (strideInts[i]));
930
932
Value PadDim = b.createOrFold <arith::ConstantOp>(
931
933
location, b.getI64IntegerAttr (paddingInts[i]));
932
934
Value ODimDDim = b.createOrFold <arith::MulIOp>(location, ODim, DDim);
933
935
Value IDim0 = b.createOrFold <arith::SubIOp>(location, ODimDDim, PadDim);
934
- Value IDim = castIndexToInt64 (b, location, InputSpatialDimValues [i]);
936
+ Value IDim = castIndexToInt64 (b, location, InputSpatialDimSizes [i]);
935
937
Value IDim0KDim =
936
- b.createOrFold <arith::AddIOp>(location, IDim0, kernelSizeIntValues [i]);
938
+ b.createOrFold <arith::AddIOp>(location, IDim0, kernelDimSizes [i]);
937
939
Value IDimPadDim = b.createOrFold <arith::AddIOp>(location, IDim, PadDim);
938
940
Value IDim1 =
939
941
b.createOrFold <arith::MinSIOp>(location, IDim0KDim, IDimPadDim);
@@ -943,11 +945,15 @@ Value PoolSizeCalculator<NumOfDims>::getPoolSize(
943
945
Value IDim1Clamped = b.createOrFold <arith::MinSIOp>(location, IDim1, IDim);
944
946
Value IDim1_IDim0_Clamped =
945
947
b.createOrFold <arith::SubIOp>(location, IDim1Clamped, IDim0Clamped);
948
+
949
+ Value poolSizeDim =
950
+ !isCountIncludePad
951
+ ? IDim1_IDim0_Clamped
952
+ : b.createOrFold <arith::SubIOp>(location, IDim1, IDim0);
946
953
if (i == 0 ) {
947
- poolSize = IDim1_IDim0_Clamped ;
954
+ poolSize = poolSizeDim ;
948
955
} else {
949
- poolSize = b.createOrFold <arith::MulIOp>(location, poolSize,
950
- IDim1_IDim0_Clamped);
956
+ poolSize = b.createOrFold <arith::MulIOp>(location, poolSize, poolSizeDim);
951
957
}
952
958
}
953
959
return poolSize;
@@ -963,26 +969,35 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
963
969
matchAndRewrite (OpTy op, typename OpTy::Adaptor adaptor,
964
970
ConversionPatternRewriter &rewriter) const override ;
965
971
966
- // Creates the average pooling operation value when the
967
- // count_include_pad parameter is equal to false.
968
- static std::optional<LogicalResult>
969
- createAvgPoolValueCountIncludePadFalseCase (
970
- bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor,
971
- ConversionPatternRewriter &rewriter, Value self, Value sumPool,
972
- Value outputTensor, Type resultType,
973
- SmallVectorImpl<Value> &kernelSizeIntValues,
972
+ // If the condition below is true, the divisor total must subtract the
973
+ // elements not counted (clamped divisor count). If false, the divisor
974
+ // is just the product of kernel dimensions.
975
+ static bool
976
+ doesAvgPoolDivisorNeedsClamping (bool ceilMode, bool countIncludePad,
977
+ SmallVectorImpl<int64_t > &strideInts,
978
+ SmallVectorImpl<int64_t > &paddingInts);
979
+
980
+ // Creates the average pooling operation value with a clamped
981
+ // divisor. The clamped divisor is the product of kernel
982
+ // dimensions minus the elements not counted; e.g., padding
983
+ // and ceiling mode implicit padding.
984
+ static LogicalResult createAveragePoolValueWithClampedDivisor (
985
+ bool ceilMode, bool countIncludePad, OpTy op,
986
+ typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter,
987
+ Value self, Value sumPool, Value outputTensor, Type resultType,
988
+ SmallVectorImpl<Value> &kernelDimSizes,
974
989
SmallVectorImpl<int64_t > &strideInts,
975
990
SmallVectorImpl<int64_t > &paddingInts,
976
991
SmallVector<AffineMap> &indexingMapsAvg,
977
992
SmallVector<utils::IteratorType> &iteratorTypesAvg);
978
993
979
- // Creates the average pooling operation value when the
980
- // count_include_pad parameter is equal to true .
981
- static LogicalResult createAvgPoolValueCountIncludePadTrueCase (
994
+ // Creates the average pooling operation value with a
995
+ // regular divisor; i.e., the product of kernel dimensions .
996
+ static LogicalResult createAveragePoolValueWithRegularDivisor (
982
997
OpTy op, typename OpTy::Adaptor &adaptor,
983
998
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
984
999
Value outputTensor, Type resultType,
985
- SmallVectorImpl<Value> &kernelSizeIntValues ,
1000
+ SmallVectorImpl<Value> &kernelDimSizes ,
986
1001
SmallVector<AffineMap> &indexingMapsAvg,
987
1002
SmallVector<utils::IteratorType> &iteratorTypesAvg);
988
1003
};
@@ -1046,27 +1061,64 @@ LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::matchAndRewrite(
1046
1061
SmallVector<utils::IteratorType> iteratorTypesAvg (
1047
1062
Dim + 2 , utils::IteratorType::parallel);
1048
1063
1049
- auto divisorOpResult = createAvgPoolValueCountIncludePadFalseCase (
1050
- countIncludePad, op, adaptor, rewriter, self, sumPool, outputTensor,
1051
- resultType, kernelSizeIntValues, strideInts, paddingInts, indexingMapsAvg,
1052
- iteratorTypesAvg);
1053
- if (divisorOpResult)
1054
- return *divisorOpResult;
1064
+ if (doesAvgPoolDivisorNeedsClamping (ceilMode, countIncludePad, strideInts,
1065
+ paddingInts)) {
1066
+ return createAveragePoolValueWithClampedDivisor (
1067
+ ceilMode, countIncludePad, op, adaptor, rewriter, self, sumPool,
1068
+ outputTensor, resultType, kernelSizeIntValues, strideInts, paddingInts,
1069
+ indexingMapsAvg, iteratorTypesAvg);
1070
+ }
1055
1071
1056
- return createAvgPoolValueCountIncludePadTrueCase (
1072
+ return createAveragePoolValueWithRegularDivisor (
1057
1073
op, adaptor, rewriter, self, sumPool, outputTensor, resultType,
1058
1074
kernelSizeIntValues, indexingMapsAvg, iteratorTypesAvg);
1075
+ }
1059
1076
1060
- return success ();
1077
+ template <typename OpTy, typename PoolingOpTy, int Dim>
1078
+ bool ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
1079
+ doesAvgPoolDivisorNeedsClamping (bool ceilMode, bool countIncludePad,
1080
+ SmallVectorImpl<int64_t > &strideInts,
1081
+ SmallVectorImpl<int64_t > &paddingInts) {
1082
+ // Determines whether the average pooling divisor needs to be clamped
1083
+ // (i.e., adjusted to exclude padded or out-of-bounds elements).
1084
+ //
1085
+ // There are two primary cases where clamping is needed:
1086
+ // 1. Padding with count_include_pad == false:
1087
+ // - If padding is applied (padding != 0) and count_include_pad is false,
1088
+ // then padding elements are *excluded* from the divisor, effectively
1089
+ // clamping the divisor to the number of valid input elements.
1090
+ //
1091
+ // 2. Ceil mode with non-unit stride:
1092
+ // - When ceil_mode is enabled, output dimensions are rounded up,
1093
+ // potentially
1094
+ // creating pooling windows that extend beyond the input tensor bounds.
1095
+ // PyTorch handles this by implicitly adding zero-padding outside the
1096
+ // tensor, but these extra (implicit) padded elements are *not* included
1097
+ // in the divisor. This behavior is independent of the count_include_pad
1098
+ // flag.
1099
+ // - If all strides are 1, ceil_mode will not produce fractional divisions,
1100
+ // so the windows will not extend beyond bounds, and no clamping occurs.
1101
+ //
1102
+ // Reference: PyTorch AvgPool2d documentation and formula for H_out/W_out:
1103
+ // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
1104
+ //
1105
+ // See torch.nn.AvgPool2d E2E tests for comprehensive coverage.
1106
+
1107
+ bool hasPadding =
1108
+ !llvm::all_of (paddingInts, [](int64_t p) { return p == 0 ; });
1109
+ bool allStridesUnitary =
1110
+ llvm::all_of (strideInts, [](int64_t s) { return s == 1 ; });
1111
+
1112
+ return (!countIncludePad && hasPadding) || (ceilMode && !allStridesUnitary);
1061
1113
}
1062
1114
1063
1115
template <typename OpTy, typename PoolingOpTy, int Dim>
1064
- std::optional< LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
1065
- createAvgPoolValueCountIncludePadFalseCase (
1066
- bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor ,
1067
- ConversionPatternRewriter &rewriter, Value self, Value sumPool ,
1068
- Value outputTensor, Type resultType,
1069
- SmallVectorImpl<Value> &kernelSizeIntValues ,
1116
+ LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
1117
+ createAveragePoolValueWithClampedDivisor (
1118
+ bool ceilMode, bool countIncludePad, OpTy op ,
1119
+ typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter ,
1120
+ Value self, Value sumPool, Value outputTensor, Type resultType,
1121
+ SmallVectorImpl<Value> &kernelDimSizes ,
1070
1122
SmallVectorImpl<int64_t > &strideInts,
1071
1123
SmallVectorImpl<int64_t > &paddingInts,
1072
1124
SmallVector<AffineMap> &indexingMapsAvg,
@@ -1075,11 +1127,6 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
1075
1127
1076
1128
constexpr int avgPoolDims = getAvgPoolNumOfDims<OpTy>();
1077
1129
1078
- bool noPadding = llvm::all_of (paddingInts, [](int64_t p) { return p == 0 ; });
1079
- if (countIncludePad || noPadding) {
1080
- // These cases are not handled here.
1081
- return std::nullopt;
1082
- }
1083
1130
if (avgPoolDims < 1 ) {
1084
1131
return rewriter.notifyMatchFailure (
1085
1132
op, " Unexpected type. Only expected AtenAvgPool1dOp, AtenAvgPool2dOp, "
@@ -1088,8 +1135,8 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
1088
1135
1089
1136
Type resultElementType = cast<RankedTensorType>(resultType).getElementType ();
1090
1137
1091
- PoolSizeCalculator<avgPoolDims> poolSizeCalculator (self, sumPool, rewriter,
1092
- loc);
1138
+ PoolSizeCalculator<avgPoolDims> poolSizeCalculator (
1139
+ self, sumPool, countIncludePad, rewriter, loc);
1093
1140
1094
1141
// AtenAvgPool2/3dOp has an optional divisor_override
1095
1142
// attribute while AtenAvgPool1dOp does not.
@@ -1110,7 +1157,7 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
1110
1157
[&](OpBuilder &b, Location loc, ValueRange args) {
1111
1158
if (!poolSize) {
1112
1159
poolSize = poolSizeCalculator.getPoolSize (
1113
- b, kernelSizeIntValues , strideInts, paddingInts);
1160
+ b, kernelDimSizes , strideInts, paddingInts);
1114
1161
}
1115
1162
Value divisor =
1116
1163
convertScalarToDtype (b, loc, poolSize, resultElementType);
@@ -1128,21 +1175,21 @@ std::optional<LogicalResult> ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
1128
1175
1129
1176
template <typename OpTy, typename PoolingOpTy, int Dim>
1130
1177
LogicalResult ConvertAtenAvgPoolOp<OpTy, PoolingOpTy, Dim>::
1131
- createAvgPoolValueCountIncludePadTrueCase (
1178
+ createAveragePoolValueWithRegularDivisor (
1132
1179
OpTy op, typename OpTy::Adaptor &adaptor,
1133
1180
ConversionPatternRewriter &rewriter, Value self, Value sumPool,
1134
1181
Value outputTensor, Type resultType,
1135
- SmallVectorImpl<Value> &kernelSizeIntValues ,
1182
+ SmallVectorImpl<Value> &kernelDimSizes ,
1136
1183
SmallVector<AffineMap> &indexingMapsAvg,
1137
1184
SmallVector<utils::IteratorType> &iteratorTypesAvg) {
1138
1185
Location loc = op->getLoc ();
1139
1186
1140
1187
Type resultElementType = cast<RankedTensorType>(resultType).getElementType ();
1141
1188
1142
- Value divisor = kernelSizeIntValues [0 ];
1143
- for (uint32_t i = 1 ; i < kernelSizeIntValues .size (); ++i) {
1144
- divisor = rewriter. createOrFold <arith::MulIOp>(loc, divisor,
1145
- kernelSizeIntValues [i]);
1189
+ Value divisor = kernelDimSizes [0 ];
1190
+ for (uint32_t i = 1 ; i < kernelDimSizes .size (); ++i) {
1191
+ divisor =
1192
+ rewriter. createOrFold <arith::MulIOp>(loc, divisor, kernelDimSizes [i]);
1146
1193
}
1147
1194
// Only average pooling 2D/3D have optional divisor override.
1148
1195
if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {
0 commit comments