From b4081d874c06408c3e4718e3070ac736e2c24354 Mon Sep 17 00:00:00 2001 From: Ivan Garcia Date: Wed, 19 Feb 2025 21:16:31 -0500 Subject: [PATCH 1/3] Extend the PyTorch avg_pool linalg lowering algorithm for the case where count_include_pad = false --- lib/Conversion/TorchToLinalg/Pooling.cpp | 384 +++++++++++------- projects/pt1/e2e_testing/xfail_sets.py | 2 + .../torch_mlir_e2e_test/test_suite/pooling.py | 108 +++++ test/Conversion/TorchToLinalg/pooling.mlir | 130 ++++++ 4 files changed, 473 insertions(+), 151 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 90b5b2af77a8..51930f1410bd 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -18,6 +18,7 @@ #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include using namespace mlir; using namespace mlir::torch; @@ -829,6 +830,228 @@ class ConvertAtenMaxUnpool3dOp final }; } // namespace +namespace { +// The following structures and the getNumOfDims method +// are used to get the number of dimensions from the +// average pooling type at compile time. +template struct AtenAvgPoolTypeNumOfDims { + static constexpr int getNumOfDims() { return -1; } +}; +template <> struct AtenAvgPoolTypeNumOfDims { + static constexpr int getNumOfDims() { return 1; } +}; +template <> struct AtenAvgPoolTypeNumOfDims { + static constexpr int getNumOfDims() { return 2; } +}; +template <> struct AtenAvgPoolTypeNumOfDims { + static constexpr int getNumOfDims() { return 3; } +}; +template constexpr int getAvgPoolNumOfDims() { + return AtenAvgPoolTypeNumOfDims::getNumOfDims(); +} +} // namespace + +namespace { +// This is a helper class to create the pooling size value +// used in the divisor of the average pooling operator. +template class PoolSizeCalculator { +public: + PoolSizeCalculator(Value self, Value sumPool, + ConversionPatternRewriter &rewriter, Location loc); + + // The algorithm for computing the divisor with + // count_include_pad equal is mainly based on pytorch + // implementation. The following code is comment + // with pytorch code. + // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78 + // Dim below stands for spatial dimension. It replaces the + // height and width labels in variables. 
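+  //
+  // Illustrative sketch only (hypothetical sizes, not taken from any
+  // test below): for one spatial dimension with kernel = 3, stride = 1,
+  // padding = 1 and input size IDim = 5, the window at output index 0 is
+  //   IDim0 = 0 * 1 - 1 = -1
+  //   IDim1 = min(IDim0 + 3, IDim + 1) = 2
+  //   clamped: IDim0 = max(IDim0, 0) = 0, IDim1 = min(IDim1, IDim) = 2
+  // so this dimension contributes IDim1 - IDim0 = 2 to the divisor:
+  // only the two in-bounds elements are counted when
+  // count_include_pad is false.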
+ Value getPoolSize(OpBuilder &b, SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &strideInts, + SmallVectorImpl &paddingInts); + +private: + int64_t DimSizeFromSumPoolType[NumOfDims]; + Value InputSpatialDimValues[NumOfDims]; + Location location; +}; + +} // namespace + +template +PoolSizeCalculator::PoolSizeCalculator( + Value self, Value sumPool, ConversionPatternRewriter &rewriter, + Location loc) + : location(loc) { + auto selfType = cast(self.getType()); + const int64_t selfRank = selfType.getRank(); + RankedTensorType sumPoolType = cast(sumPool.getType()); + const int64_t rank = sumPoolType.getRank(); + + // Store dimensions in this order: + // 0 => width, 1 => height, 2 => depth + for (int i = 0; i < NumOfDims; ++i) { + int64_t DimSizeFromSelfType = toPositiveDim(-(i + 1), selfRank); + InputSpatialDimValues[i] = + getDimOp(rewriter, location, self, DimSizeFromSelfType); + DimSizeFromSumPoolType[i] = toPositiveDim(-(i + 1), rank); + } +} + +template +Value PoolSizeCalculator::getPoolSize( + OpBuilder &b, SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &strideInts, + SmallVectorImpl &paddingInts) { + Value poolSize; + + Value cstZero = + b.createOrFold(location, b.getI64IntegerAttr(0)); + + for (int i = 0; i < NumOfDims; ++i) { + // See the link below for the PyTorch implementation where + // this is derived from: + // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78 + // Dim below stands for spatial dimension. Prior to the February + // 2025 change, these variables used "height" and "width" (or + // "h" and "w") in these intermediate variables instead of "Dim". + Value IndexODim = + b.createOrFold(location, + /*value=*/DimSizeFromSumPoolType[i]); + Value ODim = castIndexToInt64(b, location, IndexODim); + Value DDim = b.createOrFold( + location, b.getI64IntegerAttr(strideInts[i])); + Value PadDim = b.createOrFold( + location, b.getI64IntegerAttr(paddingInts[i])); + Value ODimDDim = b.createOrFold(location, ODim, DDim); + Value IDim0 = b.createOrFold(location, ODimDDim, PadDim); + Value IDim = castIndexToInt64(b, location, InputSpatialDimValues[i]); + Value IDim0KDim = + b.createOrFold(location, IDim0, kernelSizeIntValues[i]); + Value IDimPadDim = b.createOrFold(location, IDim, PadDim); + Value IDim1 = + b.createOrFold(location, IDim0KDim, IDimPadDim); + + Value IDim0Clamped = + b.createOrFold(location, IDim0, cstZero); + Value IDim1Clamped = b.createOrFold(location, IDim1, IDim); + Value IDim1_IDim0_Clamped = + b.createOrFold(location, IDim1Clamped, IDim0Clamped); + if (i == 0) { + poolSize = IDim1_IDim0_Clamped; + } else { + poolSize = b.createOrFold(location, poolSize, + IDim1_IDim0_Clamped); + } + } + return poolSize; +} + +// Creates the average pooling operation value when the +// count_include_pad parameter is equal to false. +template +static std::optional createAvgPoolValueCountIncludePadFalseCase( + bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter, Value self, Value sumPool, + Value outputTensor, Type resultType, + SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, + SmallVector &indexingMapsAvg, + SmallVector &iteratorTypesAvg) { + Location loc = op->getLoc(); + + constexpr int avgPoolDims = getAvgPoolNumOfDims(); + + bool noPadding = llvm::all_of(paddingInts, [](int64_t p) { return p == 0; }); + if (countIncludePad || noPadding) { + // These cases are not handled here. 
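+    // When count_include_pad is true, or when there is no padding at
+    // all, the divisor is simply the kernel volume (possibly replaced
+    // by divisor_override), so the caller falls through to
+    // createAvgPoolValueCountIncludePadTrueCase below.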
+ return std::nullopt; + } + if (avgPoolDims < 1) { + return rewriter.notifyMatchFailure( + op, "Unexpected type. Only expected AtenAvgPool1dOp, AtenAvgPool2dOp, " + "and AtenAvgPool3dOp."); + } + + Type resultElementType = cast(resultType).getElementType(); + + PoolSizeCalculator poolSizeCalculator(self, sumPool, rewriter, + loc); + Value avgPool = + rewriter + .create( + loc, outputTensor.getType(), sumPool, outputTensor, + /*indexingMaps=*/indexingMapsAvg, + /*iteratorTypes=*/iteratorTypesAvg, + [&](OpBuilder &b, Location loc, ValueRange args) { + auto poolSize = poolSizeCalculator.getPoolSize( + b, kernelSizeIntValues, strideInts, paddingInts); + // AtenAvgPool2/3dOp has an optional divisor_override + // attribute while AtenAvgPool1dOp does not. + if constexpr (avgPoolDims > 1) { + if (!isa(op.getDivisorOverride().getType())) + poolSize = adaptor.getDivisorOverride(); + } + Value divisor = + convertScalarToDtype(b, loc, poolSize, resultElementType); + Value avg; + if (isa(resultElementType)) + avg = b.createOrFold(loc, args[0], divisor); + else if (isa(resultElementType)) + avg = b.createOrFold(loc, args[0], divisor); + b.createOrFold(loc, avg); + }) + .getResult(0); + rewriter.replaceOpWithNewOp(op, resultType, avgPool); + return success(); +} + +// Creates the average pooling operation value when the +// count_include_pad parameter is equal to true. +template +static LogicalResult createAvgPoolValueCountIncludePadTrueCase( + OpTy op, typename OpTy::Adaptor &adaptor, + ConversionPatternRewriter &rewriter, Value self, Value sumPool, + Value outputTensor, Type resultType, + SmallVectorImpl &kernelSizeIntValues, + SmallVector &indexingMapsAvg, + SmallVector &iteratorTypesAvg) { + Location loc = op->getLoc(); + + Type resultElementType = cast(resultType).getElementType(); + + Value divisor = kernelSizeIntValues[0]; + for (uint32_t i = 1; i < kernelSizeIntValues.size(); ++i) { + divisor = rewriter.createOrFold(loc, divisor, + kernelSizeIntValues[i]); + } + // Only average pooling 2D/3D have optional divisor override. + if constexpr (!std::is_same()) { + divisor = isa(op.getDivisorOverride().getType()) + ? divisor + : adaptor.getDivisorOverride(); + } + divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType); + + Value avgPool = + rewriter + .create( + loc, outputTensor.getType(), sumPool, outputTensor, + /*indexingMaps=*/indexingMapsAvg, + /*iteratorTypes=*/iteratorTypesAvg, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value avg; + if (isa(resultElementType)) + avg = b.create(loc, args[0], divisor); + else if (isa(resultElementType)) + avg = b.create(loc, args[0], divisor); + b.create(loc, avg); + }) + .getResult(0); + rewriter.replaceOpWithNewOp(op, resultType, avgPool); + return success(); +} + namespace { template class ConvertAtenAvgPoolOp : public OpConversionPattern { @@ -892,159 +1115,18 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { 2, rewriter.getMultiDimIdentityMap(Dim + 2)); SmallVector iteratorTypesAvg( Dim + 2, utils::IteratorType::parallel); - Value avgPool; - Value divisor; - // Case1: AtenAvgPool1d/2dOp with countIncludePad=false support. 
- if constexpr (std::is_same()) { - auto selfType = cast(self.getType()); - const int64_t selfRank = selfType.getRank(); - int64_t wDim = toPositiveDim(-1, selfRank); - int64_t hDim = toPositiveDim(-2, selfRank); - Value inputHeight = getDimOp(rewriter, loc, self, hDim); - Value inputWidth = getDimOp(rewriter, loc, self, wDim); - RankedTensorType sumPoolType = cast(sumPool.getType()); - const int64_t rank = sumPoolType.getRank(); - int dimH = toPositiveDim(-2, rank); - int dimW = toPositiveDim(-1, rank); - avgPool = - rewriter - .create( - loc, outputTensor.getType(), sumPool, outputTensor, - /*indexingMaps=*/indexingMapsAvg, - /*iteratorTypes=*/iteratorTypesAvg, - [&](OpBuilder &b, Location loc, ValueRange args) { - // The algorithm for computing the divisor with - // count_include_pad is manily based on pytorch - // implementation. The following code is comment - // with pytorch code. - // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78 - Value indexOh = - b.create(loc, /*value=*/dimH); - Value oh = castIndexToInt64(b, loc, indexOh); - Value indexOw = - b.create(loc, /*value=*/dimW); - Value ow = castIndexToInt64(b, loc, indexOw); - - // int64_t ih0 = oh * dH - padH; - Value dH = rewriter.create( - loc, rewriter.getI64IntegerAttr(strideInts[0])); - Value padH = rewriter.create( - loc, rewriter.getI64IntegerAttr(paddingInts[0])); - Value ohDH = b.create(loc, oh, dH); - Value ih0 = b.create(loc, ohDH, padH); - // int64_t iw0 = ow * dW - padW; - Value dW = rewriter.create( - loc, rewriter.getI64IntegerAttr(strideInts[1])); - Value padW = rewriter.create( - loc, rewriter.getI64IntegerAttr(paddingInts[1])); - Value owDW = b.create(loc, ow, dW); - Value iw0 = b.create(loc, owDW, padW); - // int64_t ih1 = std::min(ih0 + kH, input_height + padH); - Value ih = castIndexToInt64(b, loc, inputHeight); - Value ih0KH = b.create( - loc, ih0, kernelSizeIntValues[0]); - Value ihPadH = b.create(loc, ih, padH); - Value ih1 = b.create(loc, ih0KH, ihPadH); - // int64_t iw1 = std::min(iw0 + kW, input_width + padW); - Value iw = castIndexToInt64(b, loc, inputWidth); - Value iw0KW = b.create( - loc, iw0, kernelSizeIntValues[1]); - Value iwPadW = b.create(loc, iw, padW); - Value iw1 = b.create(loc, iw0KW, iwPadW); - // int64_t pool_size = (ih1 - ih0) * (iw1 - iw0); - Value ih1Ih0 = b.create(loc, ih1, ih0); - Value iw1Iw0 = b.create(loc, iw1, iw0); - Value poolSize = - b.create(loc, ih1Ih0, iw1Iw0); - // ih0 = std::max(ih0, 0); - Value cstZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value ih0Clamped = - b.create(loc, ih0, cstZero); - // iw0 = std::max(iw0, 0); - Value iw0Clamped = - b.create(loc, iw0, cstZero); - // ih1 = std::min(ih1, input_height); - Value ih1Clamped = b.create(loc, ih1, ih); - // iw1 = std::min(iw1, input_width); - Value iw1Clamped = b.create(loc, iw1, iw); - // if (divisor_override.has_value()) { - // divisor = divisor_override.value(); - // } else { - // if(count_include_pad) { - // divisor = pool_size; - // } else { - // divisor = (ih1 - ih0) * (iw1 - iw0); - // } - // } - if (countIncludePad) { - divisor = convertScalarToDtype(b, loc, poolSize, - resultElementType); - } else { - Value ih1_ih0 = - b.create(loc, ih1Clamped, ih0Clamped); - Value iw1_iw0 = - b.create(loc, iw1Clamped, iw0Clamped); - divisor = b.create(loc, ih1_ih0, iw1_iw0); - } - // AtenAvgPool2/3dOp has an optional divisor_override - // attribute while AtenAvgPool1dOp does not. 
- if constexpr (std::is_same()) { - if (!isa( - op.getDivisorOverride().getType())) - divisor = adaptor.getDivisorOverride(); - } - - divisor = convertScalarToDtype(b, loc, divisor, - resultElementType); - Value avg; - if (isa(resultElementType)) - avg = b.create(loc, args[0], divisor); - else if (isa(resultElementType)) - avg = b.create(loc, args[0], divisor); - b.create(loc, avg); - }) - .getResult(0); - rewriter.replaceOpWithNewOp(op, resultType, avgPool); - return success(); - } - // TODO: Add support for count_include_pad equal to `False` in - // AtenAvgPool1/3dOp. - if (!countIncludePad && - !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) { - return rewriter.notifyMatchFailure( - op, "unimplemented: count_include_pad is expected to be true for " - "AtenAvgPool3dOp"); - } + auto divisorOpResult = createAvgPoolValueCountIncludePadFalseCase( + countIncludePad, op, adaptor, rewriter, self, sumPool, outputTensor, + resultType, kernelSizeIntValues, strideInts, paddingInts, + indexingMapsAvg, iteratorTypesAvg); + if (divisorOpResult) + return *divisorOpResult; + + return createAvgPoolValueCountIncludePadTrueCase( + op, adaptor, rewriter, self, sumPool, outputTensor, resultType, + kernelSizeIntValues, indexingMapsAvg, iteratorTypesAvg); - // Case2: AtenAvgPool1/3dOp without count_include_pad equal to `False`. - divisor = kernelSizeIntValues[0]; - for (uint32_t i = 1; i < kernelSizeIntValues.size(); i++) { - divisor = - rewriter.create(loc, divisor, kernelSizeIntValues[i]); - } - if constexpr (!std::is_same()) { - divisor = isa(op.getDivisorOverride().getType()) - ? divisor - : adaptor.getDivisorOverride(); - } - divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType); - avgPool = rewriter - .create( - loc, outputTensor.getType(), sumPool, outputTensor, - /*indexingMaps=*/indexingMapsAvg, - /*iteratorTypes=*/iteratorTypesAvg, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value avg; - if (isa(resultElementType)) - avg = b.create(loc, args[0], divisor); - else if (isa(resultElementType)) - avg = b.create(loc, args[0], divisor); - b.create(loc, avg); - }) - .getResult(0); - rewriter.replaceOpWithNewOp(op, resultType, avgPool); return success(); } }; diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 3dd78cc011f6..1d7dd30b56c9 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -3393,6 +3393,8 @@ "AtenKthvalueKeepDimModule_basic", "AtenKthvalueModule_basic", "AvgPool3dStaticModule_basic", + "AvgPool3dCountIncludePadFalse_basic", + "AvgPool3dCountIncludePadFalseWithoutPadding_basic", "Conv_Transpose1dModule_basic", "Conv_Transpose1dStaticModule_basic", "Conv_Transpose2dStaticModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index e2eaa4cfd0fe..ce7d2e2bc42e 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -1460,6 +1460,66 @@ def AvgPool3dStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 2, 4, 4, 4, low=-1)) +class AvgPool3dCountIncludePadFalse(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap3d = torch.nn.AvgPool3d( + kernel_size=[3, 3, 3], + stride=[1, 1, 1], + padding=[1, 1, 1], + ceil_mode=False, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([3, 3, 12, 12, 12], 
torch.float32, True), + ] + ) + def forward(self, x): + return self.ap3d(x) + + +@register_test_case(module_factory=lambda: AvgPool3dCountIncludePadFalse()) +def AvgPool3dCountIncludePadFalse_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 3, 12, 12, 12, low=-1)) + + +class AvgPool3dCountIncludePadFalseWithoutPadding(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap3d = torch.nn.AvgPool3d( + kernel_size=[3, 3, 3], + stride=[1, 1, 1], + padding=[0, 0, 0], + ceil_mode=False, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([3, 3, 12, 12, 12], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap3d(x) + + +@register_test_case( + module_factory=lambda: AvgPool3dCountIncludePadFalseWithoutPadding() +) +def AvgPool3dCountIncludePadFalseWithoutPadding_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 3, 12, 12, 12, low=-1)) + + # ============================================================================== @@ -1532,6 +1592,54 @@ def AvgPool1dStaticModule_basic(module, tu: TestUtils): module.forward(tu.randint(2, 4, 20, high=100)) +class AvgPool1dCountIncludePadFalseWithoutPadding(torch.nn.Module): + def __init__(self): + super().__init__() + self.ap1d = torch.nn.AvgPool1d( + kernel_size=3, stride=1, padding=0, ceil_mode=False, count_include_pad=False + ) + + @export + @annotate_args( + [ + None, + ([3, 4, 20], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap1d(x) + + +@register_test_case( + module_factory=lambda: AvgPool1dCountIncludePadFalseWithoutPadding() +) +def AvgPool1dCountIncludePadFalseWithoutPadding_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 20)) + + +class AvgPool1dCountIncludePadFalse(torch.nn.Module): + def __init__(self): + super().__init__() + self.ap1d = torch.nn.AvgPool1d( + kernel_size=3, stride=1, padding=1, ceil_mode=False, count_include_pad=False + ) + + @export + @annotate_args( + [ + None, + ([3, 4, 20], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap1d(x) + + +@register_test_case(module_factory=lambda: AvgPool1dCountIncludePadFalse()) +def AvgPool1dCountIncludePadFalse_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 20)) + + # ============================================================================== diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir index 558c50c4f08f..53faa1d37d4f 100644 --- a/test/Conversion/TorchToLinalg/pooling.mlir +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -95,3 +95,133 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch. 
// CHECK: } -> tensor return %4 : !torch.vtensor<[?,?,?,?,?],f32> } + +// CHECK-LABEL: func @forward_avg_pool2d +func.func @forward_avg_pool2d(%arg0: !torch.vtensor<[1,3,64,56],f32>) -> !torch.vtensor<[1,3, 61,27],f32> { + // CHECK: linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<[1, 2]> : vector<2xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x3x64x58xf32>, tensor<4x5xf32>) outs(%[[OUT1:.*]] : tensor<1x3x61x27xf32>) -> tensor<1x3x61x27xf32> + // CHECK: linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x61x27xf32>) outs(%[[OUT2:.*]] : tensor<1x3x61x27xf32>) + // CHECK-NEXT: ^bb0(%[[BIIN1:.*]]: f32, %[[BOUT1:.*]]: f32): + // CHECK-NEXT: %[[TMP1:.*]] = arith.divf %[[BIIN1:.*]], %[[CONST1:.*]] : f32 + // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32 + // CHECK-NEXT: } -> tensor<1x3x61x27xf32> + %none = torch.constant.none + %false = torch.constant.bool false + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %0 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %true, %none : !torch.vtensor<[1,3,64,56],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,61,27],f32> + return %3 : !torch.vtensor<[1,3,61,27],f32> +} + +// CHECK-LABEL: func @forward_avg_pool2d_countincludepad_false +func.func @forward_avg_pool2d_countincludepad_false(%arg0: !torch.vtensor<[1,3,64,56],f32>) -> !torch.vtensor<[1,3, 61,27],f32> { + // CHECK: linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<[1, 2]> : vector<2xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x3x64x58xf32>, tensor<4x5xf32>) outs(%[[OUT1:.*]] : tensor<1x3x61x27xf32>) -> tensor<1x3x61x27xf32> + // CHECK: linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x61x27xf32>) outs(%[[OUT2:.*]] : tensor<1x3x61x27xf32>) + // CHECK-NEXT: ^bb0(%[[BIIN1:.*]]: f32, %[[BOUT1:.*]]: f32): + // CHECK-COUNT-4: arith.minsi + // CHECK-COUNT-1: arith.divf + // CHECK: linalg.yield %[[TMP1:.*]] : f32 + // CHECK-NEXT: } -> tensor<1x3x61x27xf32> + %none = torch.constant.none + %false = torch.constant.bool false + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %0 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %false, %none : !torch.vtensor<[1,3,64,56],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,61,27],f32> + return %3 : !torch.vtensor<[1,3,61,27],f32> +} + +// CHECK-LABEL: func @forward_avg_pool3d +func.func @forward_avg_pool3d(%arg0: !torch.vtensor<[1,3,7,64,56],f32>) -> !torch.vtensor<[1,3,4,31,54],f32> { + // CHECK: 
linalg.pooling_ndhwc_sum {dilations = dense<1> : vector<3xi64>, strides = dense<[1, 2, 1]> : vector<3xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x7x66x58x3xf32>, tensor<4x5x5xf32>) outs(%[[OUT1:.*]] : tensor<1x4x31x54x3xf32>) -> tensor<1x4x31x54x3xf32> + // CHECK: linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x4x31x54xf32>) outs(%[[OUT2:.*]] : tensor<1x3x4x31x54xf32>) + // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32): + // CHECK-NEXT: %[[TMP1:.*]] = arith.divf %[[BIN1:.*]], %[[CONST1:.*]] : f32 + // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32 + // CHECK-NEXT: } -> tensor<1x3x4x31x54xf32> + %none = torch.constant.none + %false = torch.constant.bool false + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %0 = torch.prim.ListConstruct %int4, %int5, %int5 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int2, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3 = torch.aten.avg_pool3d %arg0, %0, %1, %2, %false, %true, %none : !torch.vtensor<[1,3,7,64,56],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,4,31,54],f32> + return %3 : !torch.vtensor<[1,3,4,31,54],f32> +} + +// CHECK-LABEL: func @forward_avg_pool3dd_countincludepad_false +func.func @forward_avg_pool3dd_countincludepad_false(%arg0: !torch.vtensor<[1,3,7,64,56],f32>) -> !torch.vtensor<[1,3,4,31,54],f32> { + // CHECK: linalg.pooling_ndhwc_sum {dilations = dense<1> : vector<3xi64>, strides = dense<[1, 2, 1]> : vector<3xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x7x66x58x3xf32>, tensor<4x5x5xf32>) outs(%[[OUT1:.*]] : tensor<1x4x31x54x3xf32>) -> tensor<1x4x31x54x3xf32> + // CHECK: linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x4x31x54xf32>) outs(%[[OUT2:.*]] : tensor<1x3x4x31x54xf32>) + // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32): + // CHECK-COUNT-6: arith.minsi + // CHECK-COUNT-1: arith.divf + // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32 + // CHECK-NEXT: } -> tensor<1x3x4x31x54xf32> + %none = torch.constant.none + %false = torch.constant.bool false + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %0 = torch.prim.ListConstruct %int4, %int5, %int5 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int2, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3 = torch.aten.avg_pool3d %arg0, %0, %1, %2, %false, %false, %none : !torch.vtensor<[1,3,7,64,56],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,4,31,54],f32> + return %3 : !torch.vtensor<[1,3,4,31,54],f32> +} + +// CHECK-LABEL: func @forward_avg_pool1d +func.func @forward_avg_pool1d(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,12],f32> { + // CHECK: linalg.pooling_ncw_sum {dilations = dense<1> : 
vector<1xi64>, strides = dense<1> : vector<1xi64>} ins(%[[IN1:.*]], %[[IN2:.*]] : tensor<1x512x12xf32>, tensor<1xf32>) outs(%[[OUT1:.*]] : tensor<1x512x12xf32>) -> tensor<1x512x12xf32> + // CHECK: linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[IN3:.*]] : tensor<1x512x12xf32>) outs(%[[OUT2:.*]] : tensor<1x512x12xf32> + // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32): + // CHECK-NEXT: %[[TMP1:.*]] = arith.divf %[[BIN1:.*]], %[[CONST1:.*]] : f32 + // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32 + // CHECK-NEXT: } -> tensor<1x512x12xf32> + %int1 = torch.constant.int 1 + %false = torch.constant.bool false + %true = torch.constant.bool true + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %true : !torch.vtensor<[1,512,10],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,12],f32> + return %3 : !torch.vtensor<[1,512,12],f32> +} + +// CHECK-LABEL: func @forward_avg_pool1d_countincludepad_false +func.func @forward_avg_pool1d_countincludepad_false(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,12],f32> { + // CHECK: linalg.pooling_ncw_sum {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} ins(%[[IN1:.*]], %[[IN2:.*]] : tensor<1x512x12xf32>, tensor<1xf32>) outs(%[[OUT1:.*]] : tensor<1x512x12xf32>) -> tensor<1x512x12xf32> + // CHECK: linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[IN3:.*]] : tensor<1x512x12xf32>) outs(%[[OUT2:.*]] : tensor<1x512x12xf32> + // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32): + // CHECK-COUNT-2: arith.minsi + // CHECK-COUNT-1: arith.divf + // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32 + // CHECK-NEXT: } -> tensor<1x512x12xf32> + %int1 = torch.constant.int 1 + %false = torch.constant.bool false + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %false : !torch.vtensor<[1,512,10],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,12],f32> + return %3 : !torch.vtensor<[1,512,12],f32> +} From 0123524bff6f1f7593a49f4fe0c29ae4966fa09a Mon Sep 17 00:00:00 2001 From: Ivan Garcia Date: Thu, 20 Feb 2025 08:40:00 -0500 Subject: [PATCH 2/3] Using C++ style encapsulation of methods, rather than C style. --- lib/Conversion/TorchToLinalg/Pooling.cpp | 228 +++++++++++++---------- 1 file changed, 127 insertions(+), 101 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 51930f1410bd..846feba1b125 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -947,17 +947,124 @@ Value PoolSizeCalculator::getPoolSize( return poolSize; } -// Creates the average pooling operation value when the -// count_include_pad parameter is equal to false. 
-template -static std::optional createAvgPoolValueCountIncludePadFalseCase( - bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter, Value self, Value sumPool, - Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, - SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, - SmallVector &indexingMapsAvg, - SmallVector &iteratorTypesAvg) { +namespace { +template +class ConvertAtenAvgPoolOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override; + + // Creates the average pooling operation value when the + // count_include_pad parameter is equal to false. + static std::optional + createAvgPoolValueCountIncludePadFalseCase( + bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter, Value self, Value sumPool, + Value outputTensor, Type resultType, + SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &strideInts, + SmallVectorImpl &paddingInts, + SmallVector &indexingMapsAvg, + SmallVector &iteratorTypesAvg); + + // Creates the average pooling operation value when the + // count_include_pad parameter is equal to true. + static LogicalResult createAvgPoolValueCountIncludePadTrueCase( + OpTy op, typename OpTy::Adaptor &adaptor, + ConversionPatternRewriter &rewriter, Value self, Value sumPool, + Value outputTensor, Type resultType, + SmallVectorImpl &kernelSizeIntValues, + SmallVector &indexingMapsAvg, + SmallVector &iteratorTypesAvg); +}; +} // namespace + +template +LogicalResult ConvertAtenAvgPoolOp::matchAndRewrite( + OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + Location loc = op->getLoc(); + const TypeConverter *typeConverter = this->getTypeConverter(); + Value self = adaptor.getSelf(); + + Type inputElementType = + cast(self.getType()).getElementType(); + Type resultType = typeConverter->convertType(op.getType()); + Type resultElementType = cast(resultType).getElementType(); + + bool ceilMode; + SmallVector kernelSizeIntValues; + SmallVector strideInts, paddingInts, dilationInts(Dim, 1); + if (failed(checkAndGetPoolingParameters(op, rewriter, typeConverter, + ceilMode, kernelSizeIntValues, + strideInts, paddingInts))) + return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); + + // Decode strideInts into strideInts and dilation + if (strideInts.size() == 2 * Dim) { + for (int i = 0; i < Dim; i++) { + dilationInts[i] = strideInts[Dim + i]; + } + for (int i = 0; i < Dim; i++) { + strideInts.pop_back(); + } + } + + bool countIncludePad; + if (!matchPattern(op.getCountIncludePad(), + m_TorchConstantBool(&countIncludePad))) + return rewriter.notifyMatchFailure(op, + "count_include_pad must be a constant"); + + // `sumPool` contains the result of sumpool operation over the input. + Value sumPool, paddedInput; + SmallVector outTensorShape; + if (failed(createPoolingOp( + op, rewriter, self, /*supportNonFPInput=*/true, ceilMode, + /*dimensionality=*/Dim, kernelSizeIntValues, strideInts, paddingInts, + dilationInts, rewriter.getZeroAttr(inputElementType), outTensorShape, + paddedInput, sumPool))) + return rewriter.notifyMatchFailure(op, "unable to compute sumpool"); + + // Compute the average of sumPool. 
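+  // The average is materialized as an elementwise linalg.generic over
+  // sumPool: each element is divided either by the per-output-element
+  // window size (count_include_pad == false) or by the kernel volume /
+  // divisor_override (count_include_pad == true).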
+ Value outputTensor = rewriter.create( + loc, getAsOpFoldResult(outTensorShape), resultElementType); + SmallVector indexingMapsAvg( + 2, rewriter.getMultiDimIdentityMap(Dim + 2)); + SmallVector iteratorTypesAvg( + Dim + 2, utils::IteratorType::parallel); + + auto divisorOpResult = createAvgPoolValueCountIncludePadFalseCase( + countIncludePad, op, adaptor, rewriter, self, sumPool, outputTensor, + resultType, kernelSizeIntValues, strideInts, paddingInts, indexingMapsAvg, + iteratorTypesAvg); + if (divisorOpResult) + return *divisorOpResult; + + return createAvgPoolValueCountIncludePadTrueCase( + op, adaptor, rewriter, self, sumPool, outputTensor, resultType, + kernelSizeIntValues, indexingMapsAvg, iteratorTypesAvg); + + return success(); +} + +template +std::optional ConvertAtenAvgPoolOp:: + createAvgPoolValueCountIncludePadFalseCase( + bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter, Value self, Value sumPool, + Value outputTensor, Type resultType, + SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &strideInts, + SmallVectorImpl &paddingInts, + SmallVector &indexingMapsAvg, + SmallVector &iteratorTypesAvg) { Location loc = op->getLoc(); constexpr int avgPoolDims = getAvgPoolNumOfDims(); @@ -1006,16 +1113,15 @@ static std::optional createAvgPoolValueCountIncludePadFalseCase( return success(); } -// Creates the average pooling operation value when the -// count_include_pad parameter is equal to true. -template -static LogicalResult createAvgPoolValueCountIncludePadTrueCase( - OpTy op, typename OpTy::Adaptor &adaptor, - ConversionPatternRewriter &rewriter, Value self, Value sumPool, - Value outputTensor, Type resultType, - SmallVectorImpl &kernelSizeIntValues, - SmallVector &indexingMapsAvg, - SmallVector &iteratorTypesAvg) { +template +LogicalResult ConvertAtenAvgPoolOp:: + createAvgPoolValueCountIncludePadTrueCase( + OpTy op, typename OpTy::Adaptor &adaptor, + ConversionPatternRewriter &rewriter, Value self, Value sumPool, + Value outputTensor, Type resultType, + SmallVectorImpl &kernelSizeIntValues, + SmallVector &indexingMapsAvg, + SmallVector &iteratorTypesAvg) { Location loc = op->getLoc(); Type resultElementType = cast(resultType).getElementType(); @@ -1052,86 +1158,6 @@ static LogicalResult createAvgPoolValueCountIncludePadTrueCase( return success(); } -namespace { -template -class ConvertAtenAvgPoolOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (failed(verifyLinalgCompatibleTypes(op, rewriter))) - return failure(); - - Location loc = op->getLoc(); - const TypeConverter *typeConverter = this->getTypeConverter(); - Value self = adaptor.getSelf(); - - Type inputElementType = - cast(self.getType()).getElementType(); - Type resultType = typeConverter->convertType(op.getType()); - Type resultElementType = - cast(resultType).getElementType(); - - bool ceilMode; - SmallVector kernelSizeIntValues; - SmallVector strideInts, paddingInts, dilationInts(Dim, 1); - if (failed(checkAndGetPoolingParameters(op, rewriter, typeConverter, - ceilMode, kernelSizeIntValues, - strideInts, paddingInts))) - return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); - - // Decode strideInts into strideInts and dilation - if (strideInts.size() == 2 * Dim) { - for (int i = 0; i < Dim; i++) { - dilationInts[i] = strideInts[Dim + i]; - } - for (int i = 0; i < Dim; 
i++) { - strideInts.pop_back(); - } - } - - // TODO: Add support for count_include_pad equal to `False`. - bool countIncludePad; - if (!matchPattern(op.getCountIncludePad(), - m_TorchConstantBool(&countIncludePad))) - return rewriter.notifyMatchFailure( - op, "count_include_pad must be a constant"); - - // `sumPool` contains the result of sumpool operation over the input. - Value sumPool, paddedInput; - SmallVector outTensorShape; - if (failed(createPoolingOp( - op, rewriter, self, /*supportNonFPInput=*/true, ceilMode, - /*dimensionality=*/Dim, kernelSizeIntValues, strideInts, - paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType), - outTensorShape, paddedInput, sumPool))) - return rewriter.notifyMatchFailure(op, "unable to compute sumpool"); - - // Compute the average of sumPool. - Value outputTensor = rewriter.create( - loc, getAsOpFoldResult(outTensorShape), resultElementType); - SmallVector indexingMapsAvg( - 2, rewriter.getMultiDimIdentityMap(Dim + 2)); - SmallVector iteratorTypesAvg( - Dim + 2, utils::IteratorType::parallel); - - auto divisorOpResult = createAvgPoolValueCountIncludePadFalseCase( - countIncludePad, op, adaptor, rewriter, self, sumPool, outputTensor, - resultType, kernelSizeIntValues, strideInts, paddingInts, - indexingMapsAvg, iteratorTypesAvg); - if (divisorOpResult) - return *divisorOpResult; - - return createAvgPoolValueCountIncludePadTrueCase( - op, adaptor, rewriter, self, sumPool, outputTensor, resultType, - kernelSizeIntValues, indexingMapsAvg, iteratorTypesAvg); - - return success(); - } -}; -} // namespace - /* This section is for lowering adaptive pooling ops, which cannot generally be decomposed into typical pooling ops. Given an input tensor of rank (N,C,Hin) and From 68be30f50adeaef1faebad82bdd2bb7d1d7ac005 Mon Sep 17 00:00:00 2001 From: Ivan Garcia Date: Tue, 4 Mar 2025 16:07:55 -0500 Subject: [PATCH 3/3] Update comment for rebasing with main. --- lib/Conversion/TorchToLinalg/Pooling.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 846feba1b125..3c971354783a 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -909,12 +909,12 @@ Value PoolSizeCalculator::getPoolSize( b.createOrFold(location, b.getI64IntegerAttr(0)); for (int i = 0; i < NumOfDims; ++i) { - // See the link below for the PyTorch implementation where - // this is derived from: + // See the link below for the PyTorch implementation where this is + // derived from: // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78 - // Dim below stands for spatial dimension. Prior to the February - // 2025 change, these variables used "height" and "width" (or - // "h" and "w") in these intermediate variables instead of "Dim". + // Dim below stands for spatial dimension. Prior to the February 2025 + // change, these variables used "height" and "width" (or "h" and "w") + // in these intermediate variables instead of "Dim". Value IndexODim = b.createOrFold(location, /*value=*/DimSizeFromSumPoolType[i]);
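
A note on the semantics the new count_include_pad=False tests exercise: at a
padded border position the divisor is the number of in-bounds elements in the
window rather than the full kernel volume. A minimal PyTorch sketch of that
behavior (stock torch install assumed; the sizes are illustrative and not
taken from the tests above):

    import torch

    x = torch.ones(1, 1, 4, 4)
    incl = torch.nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=True)
    excl = torch.nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)

    # The corner window covers 4 in-bounds ones out of a 3x3 kernel.
    print(incl(x)[0, 0, 0, 0].item())  # 4/9 ~= 0.444  (divides by kernel volume)
    print(excl(x)[0, 0, 0, 0].item())  # 4/4 == 1.0    (divides by in-bounds count)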