Skip to content

[mlir][linalg] Take artificial padding into account for pack/unpack folding. #150127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_DIALECT_LINALG_IR_LINALG_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
Expand Down Expand Up @@ -89,6 +90,11 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim);
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
int64_t dim);

/// Returns the outer shape in the packed domain before applying the
/// transposition.
template <typename OpTy>
SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);

} // namespace linalg
} // namespace mlir

Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,10 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
ArrayRef<int64_t> innerPermutation,
ArrayRef<int64_t> outerPermutation);

/// Returns true if it is statically known that the `sliceOp` result shape
/// is compatible with the `unPackOp`. I.e., it does not drop any tile.
bool canFoldSliceOp(tensor::ExtractSliceOp sliceOp);

/// Check if this UnPackOp is like a simple unpad operation.
/// In other words, this operation:
/// 1. drops useless dimensions (dimension of size 1), and
Expand Down
55 changes: 50 additions & 5 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4490,6 +4490,29 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
//===----------------------------------------------------------------------===//
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//

template <typename OpTy>
SmallVector<int64_t>
getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
? packOrUnPack.getDestType()
: packOrUnPack.getSourceType();
RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
? packOrUnPack.getSourceType()
: packOrUnPack.getDestType();
SmallVector<int64_t> result(
packedType.getShape().take_front(unpackedType.getRank()));
if (!packOrUnPack.getOuterDimsPerm().empty()) {
applyPermutationToVector(
result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
}
return result;
}
template SmallVector<int64_t>
getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
template SmallVector<int64_t>
getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);

// Given the (potentially) updated packed type, `newPackedTy`, generates an
// updated mixed-tile-sizes attribute. A tile size is updated only
// when:
Expand Down Expand Up @@ -5447,11 +5470,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
if (unPackOp->hasOneUse()) {
auto extractSliceUser =
dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
if (extractSliceUser &&
areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
extractSliceUser.getSourceType().getRank() ==
extractSliceUser.getResultType().getRank()) {
if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
auto newDest = rewriter.create<tensor::ExtractSliceOp>(
Expand Down Expand Up @@ -5494,6 +5513,32 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
return failure();
}

bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
// Rank-reduced folding is not supported.
if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
return false;
if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
!areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
return false;
RankedTensorType unpackedType = sliceOp.getResultType();
SmallVector<int64_t> outerShapeWithoutTranspose =
getPackedOuterShapeWithoutTransposition(*this);
for (auto [pos, tileSize] :
llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
if (unpackedType.isDynamicDim(pos))
return false;
if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
return false;
if (ShapedType::isDynamic(tileSize))
return false;
int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
unpackedType.getDimSize(pos);
if (paddingSize >= tileSize)
return false;
}
return true;
}

bool UnPackOp::isLikeUnPad() {
RankedTensorType packedTensorType = getSourceType();
return isLikePadUnPad(*this, packedTensorType);
Expand Down
38 changes: 27 additions & 11 deletions mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,31 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
return failure();

// Folding is not allowed if it introduces artificial padding. It is not
// safe to fold the ops if any dynamic dimension or tile size is present,
// because we can not infer the padding size.
RankedTensorType unpackedType = packOp.getSourceType();
SmallVector<int64_t> outerShapeWithoutTranspose =
getPackedOuterShapeWithoutTransposition(packOp);
for (auto [pos, tileSize, high] :
llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
padOp.getMixedHighPad())) {
if (unpackedType.isDynamicDim(pos))
return failure();
if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
return failure();
if (ShapedType::isDynamic(tileSize))
return failure();
std::optional<int64_t> cstHigh = getConstantIntValue(high);
if (!cstHigh)
return failure();
int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
unpackedType.getDimSize(pos);
// Do not fold the op if it requires artificial padding.
if (paddingSize + cstHigh.value() >= tileSize)
return failure();
}

rewriter.replaceOpWithNewOp<PackOp>(
packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
packOp.getMixedTiles(), constantPaddingValue,
Expand Down Expand Up @@ -251,17 +276,8 @@ struct FoldUnpackWithExtractSliceOp
if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
return failure();

if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
return rewriter.notifyMatchFailure(
sliceOp, "rank-reduced folding is not supported");
}

// Check all offsets are zeros, and all strides are ones.
if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
!areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
return rewriter.notifyMatchFailure(
sliceOp, "expects offsets to be 0s and strides to be 1s");
}
if (!unpackOp.canFoldSliceOp(sliceOp))
return failure();

// Create a new empty output tensor.
Type elementType = unpackOp.getDestType().getElementType();
Expand Down
37 changes: 27 additions & 10 deletions mlir/test/Dialect/Linalg/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1890,30 +1890,47 @@ func.func @fold_cast_unpack_dynamic_tile_size(
//===----------------------------------------------------------------------===//

func.func @fold_extract_slice_into_unpack(
%src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
) -> tensor<28x28x?xf32> {
%src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
) -> tensor<28x28x10xf32> {
%unpack = linalg.unpack %src
outer_dims_perm = [0, 1, 2]
inner_dims_pos = [1, 2]
inner_tiles = [16, 16]
into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
%extracted_slice = tensor.extract_slice %unpack
[0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
return %extracted_slice : tensor<28x28x?xf32>
[0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
return %extracted_slice : tensor<28x28x10xf32>
}

// CHECK-LABEL: func @fold_extract_slice_into_unpack
// CHECK-SAME: %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
// CHECK-SAME: %[[DEST:.+]]: tensor<28x32x?xf32>
// CHECK-SAME: %[[SIZE:.+]]: index
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[SIZE:[a-zA-Z0-9]+]]
// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
// CHECK-SAME: [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
// CHECK-SAME: [0, 0, 0] [28, 28, 10] [1, 1, 1]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
// CHECK-SAME: into %[[DEST_SLICE]]
// CHECK: return %[[UNPACK]]

// -----

func.func @no_fold_extract_slice_into_unpack_dynamic(
%src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
) -> tensor<28x28x?xf32> {
%unpack = linalg.unpack %src
outer_dims_perm = [0, 1, 2]
inner_dims_pos = [1, 2]
inner_tiles = [16, 16]
into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
%extracted_slice = tensor.extract_slice %unpack
[0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
return %extracted_slice : tensor<28x28x?xf32>
}
// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic
// CHECK: linalg.unpack
// CHECK: tensor.extract_slice

// -----

func.func @no_fold_extract_slice_into_unpack_rank_reducing(
%src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
) -> tensor<28xf32> {
Expand Down
60 changes: 44 additions & 16 deletions mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control %s | FileCheck %s --check-prefix=CONTROL

func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
func.func @fold_unpack_slice(%arg0 : tensor<2082x1x8x32xf32>) -> tensor<16649x16xf32> {
%empty = tensor.empty() : tensor<16656x16xf32>
%0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
: tensor<2082x1x8x32xf32> -> tensor<16656x16xf32>
%1 = tensor.extract_slice %0[0, 0] [16649, 16] [1, 1] : tensor<16656x16xf32> to tensor<16649x16xf32>
return %1 : tensor<16649x16xf32>
}
// CHECK-LABEL: func @fold_unpack_slice(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<16649x16xf32>
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
// CHECK-SAME: into %[[INIT]]
// CHECK: return %[[UNPACK]]

// -----

func.func @nofold_dynamic_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
%0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
: tensor<?x?x8x4xf32> -> tensor<?x?xf32>
%1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK: func @fold_unpack_slice(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x8x4xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
// CHECK: %[[INIT:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]]) : tensor<?x?xf32>
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4]
// CHECK-SAME: into %[[INIT]]
// CHECK: return %[[UNPACK]]
// CHECK-LABEL: func @nofold_dynamic_unpack_slice(
// CHECK: linalg.unpack
// CHECK: tensor.extract_slice

// -----

Expand Down Expand Up @@ -59,13 +69,13 @@ func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 :

// -----

func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
func.func @pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%padded = tensor.pad %src low[0, 0] high[15, 0] {
%padded = tensor.pad %src low[0, 0] high[7, 0] {
^bb0(%arg0: index, %arg1: index):
tensor.yield %cst : f32
} : tensor<16641x16xf32> to tensor<16656x16xf32>
} : tensor<16649x16xf32> to tensor<16656x16xf32>
%empty = tensor.empty() : tensor<2082x1x8x32xf32>
%pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
: tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
Expand All @@ -81,10 +91,10 @@ func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {

// -----

func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
func.func @nofold_pad_pack_artificial_padding(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
%padded = tensor.pad %src low[0, 0] high[15, 0] {
^bb0(%arg0: index, %arg1: index):
tensor.yield %cst : f32
} : tensor<16641x16xf32> to tensor<16656x16xf32>
Expand All @@ -93,7 +103,25 @@ func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32
: tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
return %pack : tensor<2082x1x8x32xf32>
}
// CHECK-LABEL: func.func @nofold_pad_pack
// CHECK-LABLE: func.func @nofold_pad_pack_artificial_padding(
// CHECK: tensor.pad
// CHECK: linalg.pack

// -----

func.func @nofold_pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%padded = tensor.pad %src nofold low[0, 0] high[7, 0] {
^bb0(%arg0: index, %arg1: index):
tensor.yield %cst : f32
} : tensor<16649x16xf32> to tensor<16656x16xf32>
%empty = tensor.empty() : tensor<2082x1x8x32xf32>
%pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
: tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
return %pack : tensor<2082x1x8x32xf32>
}
// CHECK-LABEL: func.func @nofold_pad_pack(
// CHECK: tensor.pad
// CHECK: linalg.pack

Expand Down
Loading