@@ -411,9 +411,64 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
411411 }
412412};
413413
414+ // / Pattern to fold a tensor.cast into a tosa.transpose operation.
415+ // /
416+ // / This pattern pushes tensor.cast operations past transpose when the cast
417+ // / goes from a more static type to a less static (more dynamic) type. This
418+ // / allows the transpose to operate on more refined types, enabling better
419+ // / optimizations and type inference in downstream operations.
420+ // /
421+ // / The pattern adds a cast back to the original result type for compatibility
422+ // / with existing users.
423+ // /
424+ // / Example:
425+ // / ```
426+ // / %cast = tensor.cast %input : tensor<6x256x40xi8> to tensor<6x256x?xi8>
427+ // / %transpose = tosa.transpose %cast {perms = [0, 2, 1]}
428+ // / : (tensor<6x256x?xi8>) -> tensor<6x?x256xi8>
429+ // / ```
430+ // / is canonicalized to:
431+ // / ```
432+ // / %transpose = tosa.transpose %input {perms = [0, 2, 1]}
433+ // / : (tensor<6x256x40xi8>) -> tensor<6x40x256xi8>
434+ // / %cast = tensor.cast %transpose
435+ // / : tensor<6x40x256xi8> to tensor<6x?x256xi8>
436+ // / ```
437+ struct TransposeOpCastFolder : public OpRewritePattern <tosa::TransposeOp> {
438+ using OpRewritePattern::OpRewritePattern;
439+
440+ LogicalResult matchAndRewrite (tosa::TransposeOp transposeOp,
441+ PatternRewriter &rewriter) const override {
442+ if (!tensor::hasFoldableTensorCastOperand (transposeOp))
443+ return rewriter.notifyMatchFailure (transposeOp,
444+ " no foldable cast operand" );
445+
446+ auto castOp = cast<tensor::CastOp>(transposeOp.getInput1 ().getDefiningOp ());
447+ auto srcType = cast<RankedTensorType>(castOp.getSource ().getType ());
448+ auto oldResultType = cast<RankedTensorType>(transposeOp.getType ());
449+
450+ ArrayRef<int32_t > perms = transposeOp.getPerms ();
451+ assert (perms.size () == static_cast <size_t >(srcType.getRank ()) &&
452+ " permutation size must match source rank" );
453+ SmallVector<int64_t > newShape;
454+ newShape.reserve (srcType.getRank ());
455+ for (int32_t perm : perms)
456+ newShape.push_back (srcType.getDimSize (perm));
457+ auto newResultType = RankedTensorType::get (
458+ newShape, srcType.getElementType (), srcType.getEncoding ());
459+ auto newTransposeOp = tosa::TransposeOp::create (
460+ rewriter, transposeOp.getLoc (), newResultType, castOp.getSource (), perms);
461+
462+ rewriter.replaceOpWithNewOp <tensor::CastOp>(
463+ transposeOp, oldResultType, newTransposeOp);
464+ return success ();
465+ }
466+ };
467+
414468void TransposeOp::getCanonicalizationPatterns (RewritePatternSet &results,
415469 MLIRContext *context) {
416- results.add <ConsolidateTransposeOptimization, TransposeIsReshape>(context);
470+ results.add <ConsolidateTransposeOptimization, TransposeIsReshape,
471+ TransposeOpCastFolder>(context);
417472}
418473
419474struct ClampIsNoOp : public OpRewritePattern <tosa::ClampOp> {
0 commit comments