diff --git a/externals/llvm-project b/externals/llvm-project
index a854c266b984..5d6d982df61d 160000
--- a/externals/llvm-project
+++ b/externals/llvm-project
@@ -1 +1 @@
-Subproject commit a854c266b98468ad4479a7d3c56a3fa76437e30d
+Subproject commit 5d6d982df61d16b6d498e6d59dd91c059679d3d8
diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index 90b5b2af77a8..6ce86d468894 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -18,6 +18,7 @@
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -829,6 +830,229 @@ class ConvertAtenMaxUnpool3dOp final
 };
 } // namespace
 
+namespace {
+// The following structures and the getNumOfDims method
+// are used to get the number of dimensions from the
+// average pooling type at compile time.
+template <typename OpTy> struct AtenAvgPoolTypeNumOfDims {
+  static constexpr int getNumOfDims() { return -1; }
+};
+template <> struct AtenAvgPoolTypeNumOfDims<AtenAvgPool1dOp> {
+  static constexpr int getNumOfDims() { return 1; }
+};
+template <> struct AtenAvgPoolTypeNumOfDims<AtenAvgPool2dOp> {
+  static constexpr int getNumOfDims() { return 2; }
+};
+template <> struct AtenAvgPoolTypeNumOfDims<AtenAvgPool3dOp> {
+  static constexpr int getNumOfDims() { return 3; }
+};
+template <typename OpTy> constexpr int getAvgPoolNumOfDims() {
+  return AtenAvgPoolTypeNumOfDims<OpTy>::getNumOfDims();
+}
+} // namespace
+
+namespace {
+// This is a helper class to create the pooling size value
+// used in the divisor of the average pooling operator.
+template <int NumOfDims> class PoolSizeCalculator {
+public:
+  PoolSizeCalculator(Value self, Value sumPool,
+                     ConversionPatternRewriter &rewriter, Location loc);
+
+  // The algorithm for computing the divisor with
+  // count_include_pad equal to false is mainly based on the PyTorch
+  // implementation. The code below is commented with the
+  // corresponding PyTorch code.
+  // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
+  // Dim below stands for spatial dimension. It replaces the
+  // height and width labels in variables.
+  Value getPoolSize(OpBuilder &b, SmallVectorImpl<Value> &kernelSizeIntValues,
+                    SmallVectorImpl<int64_t> &strideInts,
+                    SmallVectorImpl<int64_t> &paddingInts);
+
+private:
+  int64_t DimSizeFromSumPoolType[NumOfDims];
+  Value InputSpatialDimValues[NumOfDims];
+  Location location;
+};
+
+} // namespace
+
+template <int NumOfDims>
+PoolSizeCalculator<NumOfDims>::PoolSizeCalculator(
+    Value self, Value sumPool, ConversionPatternRewriter &rewriter,
+    Location loc)
+    : location(loc) {
+  auto selfType = cast<RankedTensorType>(self.getType());
+  const int64_t selfRank = selfType.getRank();
+  RankedTensorType sumPoolType = cast<RankedTensorType>(sumPool.getType());
+  const int64_t rank = sumPoolType.getRank();
+
+  // Store dimensions in this order:
+  // 0 => width, 1 => height, 2 => depth
+  for (int i = 0; i < NumOfDims; ++i) {
+    int64_t DimSizeFromSelfType = toPositiveDim(-(i + 1), selfRank);
+    InputSpatialDimValues[i] =
+        getDimOp(rewriter, location, self, DimSizeFromSelfType);
+    DimSizeFromSumPoolType[i] = toPositiveDim(-(i + 1), rank);
+  }
+}
+
+template <int NumOfDims>
+Value PoolSizeCalculator<NumOfDims>::getPoolSize(
+    OpBuilder &b, SmallVectorImpl<Value> &kernelSizeIntValues,
+    SmallVectorImpl<int64_t> &strideInts,
+    SmallVectorImpl<int64_t> &paddingInts) {
+  Value poolSize;
+
+  Value cstZero =
+      b.createOrFold<arith::ConstantOp>(location, b.getI64IntegerAttr(0));
+
+  for (int i = 0; i < NumOfDims; ++i) {
+    // See the link below for the PyTorch implementation where
+    // this is derived from:
+    // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
+    // Dim below stands for spatial dimension. Prior to the February
+    // 2025 change, these intermediate variables were named with
+    // "height" and "width" (or "h" and "w") instead of "Dim".
+    Value IndexODim =
+        b.createOrFold<linalg::IndexOp>(location,
+                                        /*value=*/DimSizeFromSumPoolType[i]);
+    Value ODim = castIndexToInt64(b, location, IndexODim);
+    Value DDim = b.createOrFold<arith::ConstantOp>(
+        location, b.getI64IntegerAttr(strideInts[i]));
+    Value PadDim = b.createOrFold<arith::ConstantOp>(
+        location, b.getI64IntegerAttr(paddingInts[i]));
+    Value ODimDDim = b.createOrFold<arith::MulIOp>(location, ODim, DDim);
+    Value IDim0 = b.createOrFold<arith::SubIOp>(location, ODimDDim, PadDim);
+    Value IDim = castIndexToInt64(b, location, InputSpatialDimValues[i]);
+    Value IDim0KDim =
+        b.createOrFold<arith::AddIOp>(location, IDim0, kernelSizeIntValues[i]);
+    Value IDimPadDim = b.createOrFold<arith::AddIOp>(location, IDim, PadDim);
+    Value IDim1 =
+        b.createOrFold<arith::MinSIOp>(location, IDim0KDim, IDimPadDim);
+
+    Value IDim0Clamped =
+        b.createOrFold<arith::MaxSIOp>(location, IDim0, cstZero);
+    Value IDim1Clamped = b.createOrFold<arith::MinSIOp>(location, IDim1, IDim);
+    Value IDim1_IDim0_Clamped =
+        b.createOrFold<arith::SubIOp>(location, IDim1Clamped, IDim0Clamped);
+    if (i == 0) {
+      poolSize = IDim1_IDim0_Clamped;
+    } else {
+      poolSize = b.createOrFold<arith::MulIOp>(location, poolSize,
+                                               IDim1_IDim0_Clamped);
+    }
+  }
+  return poolSize;
+}
+
+// Creates the average pooling operation value when the
+// count_include_pad parameter is equal to false.
+template <typename OpTy>
+static std::optional<LogicalResult> createAvgPoolValueCountIncludePadFalseCase(
+    bool countIncludePad, OpTy op, typename OpTy::Adaptor adaptor,
+    ConversionPatternRewriter &rewriter, Value self, Value sumPool,
+    Value outputTensor, Type resultType,
+    SmallVectorImpl<Value> &kernelSizeIntValues,
+    SmallVectorImpl<int64_t> &strideInts,
+    SmallVectorImpl<int64_t> &paddingInts,
+    SmallVector<AffineMap> &indexingMapsAvg,
+    SmallVector<utils::IteratorType> &iteratorTypesAvg) {
+  Location loc = op->getLoc();
+
+  constexpr int avgPoolDims = getAvgPoolNumOfDims<OpTy>();
+
+  bool noPadding = llvm::all_of(paddingInts, [](int64_t p) { return p == 0; });
+  if (countIncludePad || noPadding) {
+    // These cases are not handled here.
+    return std::nullopt;
+  }
+  if (avgPoolDims < 1) {
+    return rewriter.notifyMatchFailure(
+        op, "Unexpected type. Only expected AtenAvgPool1dOp, AtenAvgPool2dOp, "
+            "and AtenAvgPool3dOp.");
+  }
+
+  Type resultElementType = cast<RankedTensorType>(resultType).getElementType();
+
+  PoolSizeCalculator<avgPoolDims> poolSizeCalculator(self, sumPool, rewriter,
+                                                     loc);
+  Value avgPool =
+      rewriter
+          .create<linalg::GenericOp>(
+              loc, outputTensor.getType(), sumPool, outputTensor,
+              /*indexingMaps=*/indexingMapsAvg,
+              /*iteratorTypes=*/iteratorTypesAvg,
+              [&](OpBuilder &b, Location loc, ValueRange args) {
+                auto poolSize = poolSizeCalculator.getPoolSize(
+                    b, kernelSizeIntValues, strideInts, paddingInts);
+                // AtenAvgPool2/3dOp has an optional divisor_override
+                // attribute while AtenAvgPool1dOp does not.
+                if constexpr (avgPoolDims > 1) {
+                  if (!isa<Torch::NoneType>(op.getDivisorOverride().getType()))
+                    poolSize = adaptor.getDivisorOverride();
+                }
+                Value divisor =
+                    convertScalarToDtype(b, loc, poolSize, resultElementType);
+                Value avg;
+                if (isa<mlir::FloatType>(resultElementType))
+                  avg = b.createOrFold<arith::DivFOp>(loc, args[0], divisor);
+                else if (isa<mlir::IntegerType>(resultElementType))
+                  avg = b.createOrFold<arith::DivSIOp>(loc, args[0], divisor);
+                b.createOrFold<linalg::YieldOp>(loc, avg);
+              })
+          .getResult(0);
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool);
+  return success();
+}
+
+// Creates the average pooling operation value when the
+// count_include_pad parameter is equal to true.
+template <typename OpTy>
+static LogicalResult createAvgPoolValueCountIncludePadTrueCase(
+    OpTy op, typename OpTy::Adaptor &adaptor,
+    ConversionPatternRewriter &rewriter, Value self, Value sumPool,
+    Value outputTensor, Type resultType,
+    SmallVectorImpl<Value> &kernelSizeIntValues,
+    SmallVector<AffineMap> &indexingMapsAvg,
+    SmallVector<utils::IteratorType> &iteratorTypesAvg) {
+  Location loc = op->getLoc();
+
+  Type resultElementType = cast<RankedTensorType>(resultType).getElementType();
+
+  Value divisor = kernelSizeIntValues[0];
+  for (uint32_t i = 1; i < kernelSizeIntValues.size(); ++i) {
+    divisor = rewriter.createOrFold<arith::MulIOp>(loc, divisor,
+                                                   kernelSizeIntValues[i]);
+  }
+
+  // Only average pooling 2D/3D have optional divisor override.
+  if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {
+    divisor = isa<Torch::NoneType>(op.getDivisorOverride().getType())
+                  ? divisor
+                  : adaptor.getDivisorOverride();
+  }
+  divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);
+
+  Value avgPool =
+      rewriter
+          .create<linalg::GenericOp>(
+              loc, outputTensor.getType(), sumPool, outputTensor,
+              /*indexingMaps=*/indexingMapsAvg,
+              /*iteratorTypes=*/iteratorTypesAvg,
+              [&](OpBuilder &b, Location loc, ValueRange args) {
+                Value avg;
+                if (isa<mlir::FloatType>(resultElementType))
+                  avg = b.create<arith::DivFOp>(loc, args[0], divisor);
+                else if (isa<mlir::IntegerType>(resultElementType))
+                  avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
+                b.create<linalg::YieldOp>(loc, avg);
+              })
+          .getResult(0);
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool);
+  return success();
+}
+
 namespace {
 template <typename OpTy, typename PoolingOpTy, int Dim>
 class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
@@ -892,159 +1116,18 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern<OpTy> {
         2, rewriter.getMultiDimIdentityMap(Dim + 2));
     SmallVector<utils::IteratorType> iteratorTypesAvg(
         Dim + 2, utils::IteratorType::parallel);
-    Value avgPool;
-    Value divisor;
-    // Case1: AtenAvgPool1d/2dOp with countIncludePad=false support.
-    if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
-      auto selfType = cast<RankedTensorType>(self.getType());
-      const int64_t selfRank = selfType.getRank();
-      int64_t wDim = toPositiveDim(-1, selfRank);
-      int64_t hDim = toPositiveDim(-2, selfRank);
-      Value inputHeight = getDimOp(rewriter, loc, self, hDim);
-      Value inputWidth = getDimOp(rewriter, loc, self, wDim);
-      RankedTensorType sumPoolType = cast<RankedTensorType>(sumPool.getType());
-      const int64_t rank = sumPoolType.getRank();
-      int dimH = toPositiveDim(-2, rank);
-      int dimW = toPositiveDim(-1, rank);
-      avgPool =
-          rewriter
-              .create<linalg::GenericOp>(
-                  loc, outputTensor.getType(), sumPool, outputTensor,
-                  /*indexingMaps=*/indexingMapsAvg,
-                  /*iteratorTypes=*/iteratorTypesAvg,
-                  [&](OpBuilder &b, Location loc, ValueRange args) {
-                    // The algorithm for computing the divisor with
-                    // count_include_pad is manily based on pytorch
-                    // implementation. The following code is comment
-                    // with pytorch code.
-                    // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78
-                    Value indexOh =
-                        b.create<linalg::IndexOp>(loc, /*value=*/dimH);
-                    Value oh = castIndexToInt64(b, loc, indexOh);
-                    Value indexOw =
-                        b.create<linalg::IndexOp>(loc, /*value=*/dimW);
-                    Value ow = castIndexToInt64(b, loc, indexOw);
-
-                    // int64_t ih0 = oh * dH - padH;
-                    Value dH = rewriter.create<arith::ConstantOp>(
-                        loc, rewriter.getI64IntegerAttr(strideInts[0]));
-                    Value padH = rewriter.create<arith::ConstantOp>(
-                        loc, rewriter.getI64IntegerAttr(paddingInts[0]));
-                    Value ohDH = b.create<arith::MulIOp>(loc, oh, dH);
-                    Value ih0 = b.create<arith::SubIOp>(loc, ohDH, padH);
-                    // int64_t iw0 = ow * dW - padW;
-                    Value dW = rewriter.create<arith::ConstantOp>(
-                        loc, rewriter.getI64IntegerAttr(strideInts[1]));
-                    Value padW = rewriter.create<arith::ConstantOp>(
-                        loc, rewriter.getI64IntegerAttr(paddingInts[1]));
-                    Value owDW = b.create<arith::MulIOp>(loc, ow, dW);
-                    Value iw0 = b.create<arith::SubIOp>(loc, owDW, padW);
-                    // int64_t ih1 = std::min(ih0 + kH, input_height + padH);
-                    Value ih = castIndexToInt64(b, loc, inputHeight);
-                    Value ih0KH = b.create<arith::AddIOp>(
-                        loc, ih0, kernelSizeIntValues[0]);
-                    Value ihPadH = b.create<arith::AddIOp>(loc, ih, padH);
-                    Value ih1 = b.create<arith::MinSIOp>(loc, ih0KH, ihPadH);
-                    // int64_t iw1 = std::min(iw0 + kW, input_width + padW);
-                    Value iw = castIndexToInt64(b, loc, inputWidth);
-                    Value iw0KW = b.create<arith::AddIOp>(
-                        loc, iw0, kernelSizeIntValues[1]);
-                    Value iwPadW = b.create<arith::AddIOp>(loc, iw, padW);
-                    Value iw1 = b.create<arith::MinSIOp>(loc, iw0KW, iwPadW);
-                    // int64_t pool_size = (ih1 - ih0) * (iw1 - iw0);
-                    Value ih1Ih0 = b.create<arith::SubIOp>(loc, ih1, ih0);
-                    Value iw1Iw0 = b.create<arith::SubIOp>(loc, iw1, iw0);
-                    Value poolSize =
-                        b.create<arith::MulIOp>(loc, ih1Ih0, iw1Iw0);
-                    // ih0 = std::max(ih0, 0);
-                    Value cstZero = rewriter.create<arith::ConstantOp>(
-                        loc, rewriter.getI64IntegerAttr(0));
-                    Value ih0Clamped =
-                        b.create<arith::MaxSIOp>(loc, ih0, cstZero);
-                    // iw0 = std::max(iw0, 0);
-                    Value iw0Clamped =
-                        b.create<arith::MaxSIOp>(loc, iw0, cstZero);
-                    // ih1 = std::min(ih1, input_height);
-                    Value ih1Clamped = b.create<arith::MinSIOp>(loc, ih1, ih);
-                    // iw1 = std::min(iw1, input_width);
-                    Value iw1Clamped = b.create<arith::MinSIOp>(loc, iw1, iw);
-                    // if (divisor_override.has_value()) {
-                    //   divisor = divisor_override.value();
-                    // } else {
-                    //   if(count_include_pad) {
-                    //     divisor = pool_size;
-                    //   } else {
-                    //     divisor = (ih1 - ih0) * (iw1 - iw0);
-                    //   }
-                    // }
-                    if (countIncludePad) {
-                      divisor = convertScalarToDtype(b, loc, poolSize,
-                                                     resultElementType);
-                    } else {
-                      Value ih1_ih0 =
-                          b.create<arith::SubIOp>(loc, ih1Clamped, ih0Clamped);
-                      Value iw1_iw0 =
-                          b.create<arith::SubIOp>(loc, iw1Clamped, iw0Clamped);
-                      divisor = b.create<arith::MulIOp>(loc, ih1_ih0, iw1_iw0);
-                    }
-                    // AtenAvgPool2/3dOp has an optional divisor_override
-                    // attribute while AtenAvgPool1dOp does not.
-                    if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
-                      if (!isa<Torch::NoneType>(
-                              op.getDivisorOverride().getType()))
-                        divisor = adaptor.getDivisorOverride();
-                    }
-
-                    divisor = convertScalarToDtype(b, loc, divisor,
-                                                   resultElementType);
-                    Value avg;
-                    if (isa<mlir::FloatType>(resultElementType))
-                      avg = b.create<arith::DivFOp>(loc, args[0], divisor);
-                    else if (isa<mlir::IntegerType>(resultElementType))
-                      avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
-                    b.create<linalg::YieldOp>(loc, avg);
-                  })
-              .getResult(0);
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool);
-      return success();
-    }
-    // TODO: Add support for count_include_pad equal to `False` in
-    // AtenAvgPool1/3dOp.
-    if (!countIncludePad &&
-        !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) {
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: count_include_pad is expected to be true for "
-              "AtenAvgPool3dOp");
-    }
+    auto divisorOpResult = createAvgPoolValueCountIncludePadFalseCase(
+        countIncludePad, op, adaptor, rewriter, self, sumPool, outputTensor,
+        resultType, kernelSizeIntValues, strideInts, paddingInts,
+        indexingMapsAvg, iteratorTypesAvg);
+    if (divisorOpResult)
+      return *divisorOpResult;
+
+    return createAvgPoolValueCountIncludePadTrueCase(
+        op, adaptor, rewriter, self, sumPool, outputTensor, resultType,
+        kernelSizeIntValues, indexingMapsAvg, iteratorTypesAvg);
 
-    // Case2: AtenAvgPool1/3dOp without count_include_pad equal to `False`.
-    divisor = kernelSizeIntValues[0];
-    for (uint32_t i = 1; i < kernelSizeIntValues.size(); i++) {
-      divisor =
-          rewriter.create<arith::MulIOp>(loc, divisor, kernelSizeIntValues[i]);
-    }
-    if constexpr (!std::is_same<OpTy, AtenAvgPool1dOp>()) {
-      divisor = isa<Torch::NoneType>(op.getDivisorOverride().getType())
-                    ? divisor
-                    : adaptor.getDivisorOverride();
-    }
-    divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType);
-    avgPool = rewriter
-                  .create<linalg::GenericOp>(
-                      loc, outputTensor.getType(), sumPool, outputTensor,
-                      /*indexingMaps=*/indexingMapsAvg,
-                      /*iteratorTypes=*/iteratorTypesAvg,
-                      [&](OpBuilder &b, Location loc, ValueRange args) {
-                        Value avg;
-                        if (isa<mlir::FloatType>(resultElementType))
-                          avg = b.create<arith::DivFOp>(loc, args[0], divisor);
-                        else if (isa<mlir::IntegerType>(resultElementType))
-                          avg = b.create<arith::DivSIOp>(loc, args[0], divisor);
-                        b.create<linalg::YieldOp>(loc, avg);
-                      })
-                  .getResult(0);
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, avgPool);
     return success();
   }
 };
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 3dd78cc011f6..1d7dd30b56c9 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -3393,6 +3393,8 @@
     "AtenKthvalueKeepDimModule_basic",
     "AtenKthvalueModule_basic",
     "AvgPool3dStaticModule_basic",
+    "AvgPool3dCountIncludePadFalse_basic",
+    "AvgPool3dCountIncludePadFalseWithoutPadding_basic",
     "Conv_Transpose1dModule_basic",
     "Conv_Transpose1dStaticModule_basic",
     "Conv_Transpose2dStaticModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
index e2eaa4cfd0fe..ce7d2e2bc42e 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -1460,6 +1460,66 @@ def AvgPool3dStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 2, 4, 4, 4, low=-1))
 
 
+class AvgPool3dCountIncludePadFalse(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.ap3d = torch.nn.AvgPool3d(
+            kernel_size=[3, 3, 3],
+            stride=[1, 1, 1],
+            padding=[1, 1, 1],
+            ceil_mode=False,
+            count_include_pad=False,
+            divisor_override=None,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3, 3, 12, 12, 12], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap3d(x)
+
+
+@register_test_case(module_factory=lambda: AvgPool3dCountIncludePadFalse())
+def AvgPool3dCountIncludePadFalse_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 3, 12, 12, 12, low=-1))
+
+
+class AvgPool3dCountIncludePadFalseWithoutPadding(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.ap3d = torch.nn.AvgPool3d(
+            kernel_size=[3, 3, 3],
+            stride=[1, 1, 1],
+            padding=[0, 0, 0],
+            ceil_mode=False,
+            count_include_pad=False,
+            divisor_override=None,
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3, 3, 12, 12, 12], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap3d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AvgPool3dCountIncludePadFalseWithoutPadding()
+)
+def AvgPool3dCountIncludePadFalseWithoutPadding_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 3, 12, 12, 12, low=-1))
+
+
 # ==============================================================================
 
 
@@ -1532,6 +1592,54 @@ def AvgPool1dStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(2, 4, 20, high=100))
 
 
+class AvgPool1dCountIncludePadFalseWithoutPadding(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ap1d = torch.nn.AvgPool1d(
+            kernel_size=3, stride=1, padding=0, ceil_mode=False, count_include_pad=False
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3, 4, 20], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap1d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AvgPool1dCountIncludePadFalseWithoutPadding()
+)
+def AvgPool1dCountIncludePadFalseWithoutPadding_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 20))
+
+
+class AvgPool1dCountIncludePadFalse(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.ap1d = torch.nn.AvgPool1d(
+            kernel_size=3, stride=1, padding=1, ceil_mode=False, count_include_pad=False
+        )
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3, 4, 20], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.ap1d(x)
+
+
+@register_test_case(module_factory=lambda: AvgPool1dCountIncludePadFalse())
+def AvgPool1dCountIncludePadFalse_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 20))
+
+
 # ==============================================================================
 
 
diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir
index 558c50c4f08f..53faa1d37d4f 100644
--- a/test/Conversion/TorchToLinalg/pooling.mlir
+++ b/test/Conversion/TorchToLinalg/pooling.mlir
@@ -95,3 +95,133 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.
 // CHECK: } -> tensor<?x?x?x?x?xf32>
   return %4 : !torch.vtensor<[?,?,?,?,?],f32>
 }
+
+// CHECK-LABEL: func @forward_avg_pool2d
+func.func @forward_avg_pool2d(%arg0: !torch.vtensor<[1,3,64,56],f32>) -> !torch.vtensor<[1,3,61,27],f32> {
+  // CHECK: linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<[1, 2]> : vector<2xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x3x64x58xf32>, tensor<4x5xf32>) outs(%[[OUT1:.*]] : tensor<1x3x61x27xf32>) -> tensor<1x3x61x27xf32>
+  // CHECK: linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x61x27xf32>) outs(%[[OUT2:.*]] : tensor<1x3x61x27xf32>)
+  // CHECK-NEXT: ^bb0(%[[BIIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
+  // CHECK-NEXT: %[[TMP1:.*]] = arith.divf %[[BIIN1:.*]], %[[CONST1:.*]] : f32
+  // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32
+  // CHECK-NEXT: } -> tensor<1x3x61x27xf32>
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %int5 = torch.constant.int 5
+  %0 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %true, %none : !torch.vtensor<[1,3,64,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,61,27],f32>
+  return %3 : !torch.vtensor<[1,3,61,27],f32>
+}
+
+// CHECK-LABEL: func @forward_avg_pool2d_countincludepad_false
+func.func @forward_avg_pool2d_countincludepad_false(%arg0: !torch.vtensor<[1,3,64,56],f32>) -> !torch.vtensor<[1,3,61,27],f32> {
+  // CHECK: linalg.pooling_nchw_sum {dilations = dense<1> : vector<2xi64>, strides = dense<[1, 2]> : vector<2xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x3x64x58xf32>, tensor<4x5xf32>) outs(%[[OUT1:.*]] : tensor<1x3x61x27xf32>) -> tensor<1x3x61x27xf32>
+  // CHECK: linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x61x27xf32>) outs(%[[OUT2:.*]] : tensor<1x3x61x27xf32>)
+  // CHECK-NEXT: ^bb0(%[[BIIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
+  // CHECK-COUNT-4: arith.minsi
+  // CHECK-COUNT-1: arith.divf
+  // CHECK: linalg.yield %[[TMP1:.*]] : f32
+  // CHECK-NEXT: } -> tensor<1x3x61x27xf32>
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %int5 = torch.constant.int 5
+  %0 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %false, %none : !torch.vtensor<[1,3,64,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,61,27],f32>
+  return %3 : !torch.vtensor<[1,3,61,27],f32>
+}
+
+// CHECK-LABEL: func @forward_avg_pool3d
+func.func @forward_avg_pool3d(%arg0: !torch.vtensor<[1,3,7,64,56],f32>) -> !torch.vtensor<[1,3,4,31,54],f32> {
+  // CHECK: linalg.pooling_ndhwc_sum {dilations = dense<1> : vector<3xi64>, strides = dense<[1, 2, 1]> : vector<3xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x7x66x58x3xf32>, tensor<4x5x5xf32>) outs(%[[OUT1:.*]] : tensor<1x4x31x54x3xf32>) -> tensor<1x4x31x54x3xf32>
+  // CHECK: linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x4x31x54xf32>) outs(%[[OUT2:.*]] : tensor<1x3x4x31x54xf32>)
+  // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
+  // CHECK-NEXT: %[[TMP1:.*]] = arith.divf %[[BIN1:.*]], %[[CONST1:.*]] : f32
+  // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32
+  // CHECK-NEXT: } -> tensor<1x3x4x31x54xf32>
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %int5 = torch.constant.int 5
+  %0 = torch.prim.ListConstruct %int4, %int5, %int5 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int2, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool3d %arg0, %0, %1, %2, %false, %true, %none : !torch.vtensor<[1,3,7,64,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,4,31,54],f32>
+  return %3 : !torch.vtensor<[1,3,4,31,54],f32>
+}
+
+// CHECK-LABEL: func @forward_avg_pool3d_countincludepad_false
+func.func @forward_avg_pool3d_countincludepad_false(%arg0: !torch.vtensor<[1,3,7,64,56],f32>) -> !torch.vtensor<[1,3,4,31,54],f32> {
+  // CHECK: linalg.pooling_ndhwc_sum {dilations = dense<1> : vector<3xi64>, strides = dense<[1, 2, 1]> : vector<3xi64>} ins(%[[IN1:.*]], %[[KSIZE1:.*]] : tensor<1x7x66x58x3xf32>, tensor<4x5x5xf32>) outs(%[[OUT1:.*]] : tensor<1x4x31x54x3xf32>) -> tensor<1x4x31x54x3xf32>
+  // CHECK: linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[IN2:.*]] : tensor<1x3x4x31x54xf32>) outs(%[[OUT2:.*]] : tensor<1x3x4x31x54xf32>)
+  // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
+  // CHECK-COUNT-6: arith.minsi
+  // CHECK-COUNT-1: arith.divf
+  // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32
+  // CHECK-NEXT: } -> tensor<1x3x4x31x54xf32>
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %int5 = torch.constant.int 5
+  %0 = torch.prim.ListConstruct %int4, %int5, %int5 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int2, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int1, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool3d %arg0, %0, %1, %2, %false, %false, %none : !torch.vtensor<[1,3,7,64,56],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,4,31,54],f32>
+  return %3 : !torch.vtensor<[1,3,4,31,54],f32>
+}
+
+// CHECK-LABEL: func @forward_avg_pool1d
+func.func @forward_avg_pool1d(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,12],f32> {
+  // CHECK: linalg.pooling_ncw_sum {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} ins(%[[IN1:.*]], %[[IN2:.*]] : tensor<1x512x12xf32>, tensor<1xf32>) outs(%[[OUT1:.*]] : tensor<1x512x12xf32>) -> tensor<1x512x12xf32>
+  // CHECK: linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[IN3:.*]] : tensor<1x512x12xf32>) outs(%[[OUT2:.*]] : tensor<1x512x12xf32>)
+  // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
+  // CHECK-NEXT: %[[TMP1:.*]] = arith.divf %[[BIN1:.*]], %[[CONST1:.*]] : f32
+  // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32
+  // CHECK-NEXT: } -> tensor<1x512x12xf32>
+  %int1 = torch.constant.int 1
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %true : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,12],f32>
+  return %3 : !torch.vtensor<[1,512,12],f32>
+}
+
+// CHECK-LABEL: func @forward_avg_pool1d_countincludepad_false
+func.func @forward_avg_pool1d_countincludepad_false(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,12],f32> {
+  // CHECK: linalg.pooling_ncw_sum {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} ins(%[[IN1:.*]], %[[IN2:.*]] : tensor<1x512x12xf32>, tensor<1xf32>) outs(%[[OUT1:.*]] : tensor<1x512x12xf32>) -> tensor<1x512x12xf32>
+  // CHECK: linalg.generic {indexing_maps = [#map5, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[IN3:.*]] : tensor<1x512x12xf32>) outs(%[[OUT2:.*]] : tensor<1x512x12xf32>)
+  // CHECK-NEXT: ^bb0(%[[BIN1:.*]]: f32, %[[BOUT1:.*]]: f32):
+  // CHECK-COUNT-2: arith.minsi
+  // CHECK-COUNT-1: arith.divf
+  // CHECK-NEXT: linalg.yield %[[TMP1:.*]] : f32
+  // CHECK-NEXT: } -> tensor<1x512x12xf32>
+  %int1 = torch.constant.int 1
+  %false = torch.constant.bool false
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %false : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,12],f32>
+  return %3 : !torch.vtensor<[1,512,12],f32>
+}
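
Reviewer note (illustrative, not part of the patch): the divisor logic that PoolSizeCalculator::getPoolSize lowers to arith ops can be summarized in a few lines of Python. This is a minimal sketch assuming the same clamped-window rule as the PyTorch AvgPoolKernel.cpp code referenced in the comments; the helper names pool_size_one_dim and divisor are hypothetical and exist only for illustration. With count_include_pad=False the divisor is the number of non-padding input elements covered by each output window, with count_include_pad=True it is the kernel volume, and divisor_override (AvgPool2d/3d only) takes precedence over both.

import math

import torch


def pool_size_one_dim(o, stride, pad, kernel, in_size):
    # Clamp the window [i0, i1) for one spatial dimension to the unpadded
    # input, mirroring aten/src/ATen/native/cpu/AvgPoolKernel.cpp.
    i0 = o * stride - pad
    i1 = min(i0 + kernel, in_size + pad)
    i0 = max(i0, 0)
    i1 = min(i1, in_size)
    return i1 - i0


def divisor(out_idx, strides, pads, kernels, in_sizes,
            count_include_pad, divisor_override=None):
    # out_idx, strides, pads, kernels, in_sizes are per-spatial-dim tuples.
    if divisor_override is not None:
        return divisor_override
    if count_include_pad:
        return math.prod(kernels)
    return math.prod(
        pool_size_one_dim(o, s, p, k, n)
        for o, s, p, k, n in zip(out_idx, strides, pads, kernels, in_sizes)
    )


# Quick check: with count_include_pad=False, averaging an all-ones input gives
# exactly 1.0 everywhere, because the per-window sum equals the divisor.
x = torch.ones(1, 1, 5)
out = torch.nn.AvgPool1d(3, stride=1, padding=1, count_include_pad=False)(x)
assert torch.allclose(out, torch.ones_like(out))
# Divisors along the single spatial dimension of the example above.
assert [pool_size_one_dim(o, 1, 1, 3, 5) for o in range(5)] == [2, 3, 3, 3, 2]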