
Commit 8f02880

Author: Tomer Solomon (committed)
[mlir][tosa] Fold tensor.cast into tosa.transpose
Push a relaxing tensor.cast past tosa.transpose when the cast goes from a more static type to a more dynamic one. The transpose then operates on the more static input type, preserving shape information; a cast back to the original result type is inserted for compatibility with existing users.
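Concretely, mirroring the example in the new pattern's doc comment below (written here with the perms attribute syntax the tests use), the rewrite turns

```mlir
%cast = tensor.cast %input : tensor<6x256x40xi8> to tensor<6x256x?xi8>
%transpose = tosa.transpose %cast {perms = array<i32: 0, 2, 1>}
    : (tensor<6x256x?xi8>) -> tensor<6x?x256xi8>
```

into

```mlir
%transpose = tosa.transpose %input {perms = array<i32: 0, 2, 1>}
    : (tensor<6x256x40xi8>) -> tensor<6x40x256xi8>
%cast = tensor.cast %transpose : tensor<6x40x256xi8> to tensor<6x?x256xi8>
```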
1 parent: 22257e8

2 files changed: +88 −1 lines

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 56 additions & 1 deletion
@@ -411,9 +411,64 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
   }
 };
 
+/// Pattern to fold a tensor.cast into a tosa.transpose operation.
+///
+/// This pattern pushes tensor.cast operations past transpose when the cast
+/// goes from a more static type to a less static (more dynamic) type. This
+/// allows the transpose to operate on more refined types, enabling better
+/// optimizations and type inference in downstream operations.
+///
+/// The pattern adds a cast back to the original result type for compatibility
+/// with existing users.
+///
+/// Example:
+/// ```
+/// %cast = tensor.cast %input : tensor<6x256x40xi8> to tensor<6x256x?xi8>
+/// %transpose = tosa.transpose %cast {perms = [0, 2, 1]}
+///     : (tensor<6x256x?xi8>) -> tensor<6x?x256xi8>
+/// ```
+/// is canonicalized to:
+/// ```
+/// %transpose = tosa.transpose %input {perms = [0, 2, 1]}
+///     : (tensor<6x256x40xi8>) -> tensor<6x40x256xi8>
+/// %cast = tensor.cast %transpose
+///     : tensor<6x40x256xi8> to tensor<6x?x256xi8>
+/// ```
+struct TransposeOpCastFolder : public OpRewritePattern<tosa::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    if (!tensor::hasFoldableTensorCastOperand(transposeOp))
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         "no foldable cast operand");
+
+    auto castOp = cast<tensor::CastOp>(transposeOp.getInput1().getDefiningOp());
+    auto srcType = cast<RankedTensorType>(castOp.getSource().getType());
+    auto oldResultType = cast<RankedTensorType>(transposeOp.getType());
+
+    ArrayRef<int32_t> perms = transposeOp.getPerms();
+    assert(perms.size() == static_cast<size_t>(srcType.getRank()) &&
+           "permutation size must match source rank");
+    SmallVector<int64_t> newShape;
+    newShape.reserve(srcType.getRank());
+    for (int32_t perm : perms)
+      newShape.push_back(srcType.getDimSize(perm));
+    auto newResultType = RankedTensorType::get(
+        newShape, srcType.getElementType(), srcType.getEncoding());
+    auto newTransposeOp = tosa::TransposeOp::create(
+        rewriter, transposeOp.getLoc(), newResultType, castOp.getSource(), perms);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        transposeOp, oldResultType, newTransposeOp);
+    return success();
+  }
+};
+
 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
+  results.add<ConsolidateTransposeOptimization, TransposeIsReshape,
+              TransposeOpCastFolder>(context);
 }
 
 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
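The heart of the rewrite above is the permuted-shape computation, newShape[i] = srcShape[perms[i]], with dynamic dimensions passing through unchanged. A minimal standalone sketch of just that step (a hypothetical permuteShape helper in plain C++, not part of the patch; -1 is this sketch's dynamic-size sentinel, standing in for MLIR's ShapedType::kDynamic):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for MLIR's ShapedType::kDynamic sentinel (sketch only).
constexpr int64_t kDynamic = -1;

// newShape[i] = srcShape[perms[i]]; dynamic sizes propagate as-is.
std::vector<int64_t> permuteShape(const std::vector<int64_t> &srcShape,
                                  const std::vector<int32_t> &perms) {
  assert(perms.size() == srcShape.size() && "permutation must match rank");
  std::vector<int64_t> out;
  out.reserve(perms.size());
  for (int32_t p : perms)
    out.push_back(srcShape[p]);
  return out;
}

// permuteShape({6, 256, 40}, {0, 2, 1})       -> {6, 40, 256}
// permuteShape({6, 256, kDynamic}, {0, 2, 1}) -> {6, kDynamic, 256}
```

This is why running the transpose on the pre-cast operand pays off: from the static source shape 6x256x40 every result dimension comes out static, whereas the post-cast shape 6x256x? would force a dynamic dimension into the result. The dynamic type survives only in the final compatibility cast back to the original result type.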

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 32 additions & 0 deletions
@@ -411,6 +411,38 @@ func.func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
 
 // -----
 
+// CHECK-LABEL: @fold_relaxing_cast_into_transpose
+func.func @fold_relaxing_cast_into_transpose(%arg0: tensor<6x256x40xi8>) -> tensor<6x?x256xi8> {
+  // CHECK: %[[VAL_1:.*]] = tosa.transpose %arg0 {perms = array<i32: 0, 2, 1>} : (tensor<6x256x40xi8>) -> tensor<6x40x256xi8>
+  // CHECK: tensor.cast %[[VAL_1]] : tensor<6x40x256xi8> to tensor<6x?x256xi8>
+  %0 = tensor.cast %arg0 : tensor<6x256x40xi8> to tensor<6x256x?xi8>
+  %1 = tosa.transpose %0 {perms = array<i32: 0, 2, 1>} : (tensor<6x256x?xi8>) -> tensor<6x?x256xi8>
+  return %1 : tensor<6x?x256xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @no_fold_refining_cast_into_transpose(
+func.func @no_fold_refining_cast_into_transpose(%arg0: tensor<?x?x256xf32>) -> tensor<?x256x8xf32> {
+  // CHECK: %[[VAL_1:.*]] = tensor.cast %arg0 : tensor<?x?x256xf32> to tensor<?x8x256xf32>
+  // CHECK: tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 1>} : (tensor<?x8x256xf32>) -> tensor<?x256x8xf32>
+  %0 = tensor.cast %arg0 : tensor<?x?x256xf32> to tensor<?x8x256xf32>
+  %1 = tosa.transpose %0 {perms = array<i32: 0, 2, 1>} : (tensor<?x8x256xf32>) -> tensor<?x256x8xf32>
+  return %1 : tensor<?x256x8xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @elide_identity_cast_before_transpose
+func.func @elide_identity_cast_before_transpose(%arg0: tensor<?x8x256xf32>) -> tensor<?x256x8xf32> {
+  // CHECK-NOT: tensor.cast
+  %0 = tensor.cast %arg0 : tensor<?x8x256xf32> to tensor<?x8x256xf32>
+  %1 = tosa.transpose %0 {perms = array<i32: 0, 2, 1>} : (tensor<?x8x256xf32>) -> tensor<?x256x8xf32>
+  return %1 : tensor<?x256x8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @conv2d_stride_2
 func.func @conv2d_stride_2(%arg0: tensor<4x11x11x2xf32>) -> tensor<4x6x6x3xf32> {
   // CHECK: tosa.conv2d