|
12 | 12 | #include "mlir/Dialect/Arith/IR/Arith.h" |
13 | 13 | #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" |
14 | 14 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
| 15 | +#include "mlir/Dialect/UB/IR/UBOps.h" |
15 | 16 | #include "mlir/IR/AffineExpr.h" |
16 | 17 | #include "mlir/IR/AffineMap.h" |
17 | 18 | #include "mlir/IR/Builders.h" |
@@ -445,14 +446,45 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> { |
445 | 446 | return success(replaced); |
446 | 447 | } |
447 | 448 | }; |
| 449 | + |
| 450 | +/// If the destination block of a conditional branch contains only |
| 451 | +/// ub.unreachable, unconditionally branch to the other destination. |
| 452 | +struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> { |
| 453 | + using OpRewritePattern<CondBranchOp>::OpRewritePattern; |
| 454 | + |
| 455 | + LogicalResult matchAndRewrite(CondBranchOp condbr, |
| 456 | + PatternRewriter &rewriter) const override { |
| 457 | + // If the "true" destination is unreachable, branch to the "false" |
| 458 | + // destination. |
| 459 | + Block *trueDest = condbr.getTrueDest(); |
| 460 | + Block *falseDest = condbr.getFalseDest(); |
| 461 | + if (llvm::hasSingleElement(*trueDest) && |
| 462 | + isa<ub::UnreachableOp>(trueDest->getTerminator())) { |
| 463 | + rewriter.replaceOpWithNewOp<BranchOp>(condbr, falseDest, |
| 464 | + condbr.getFalseOperands()); |
| 465 | + return success(); |
| 466 | + } |
| 467 | + |
| 468 | + // If the "false" destination is unreachable, branch to the "true" |
| 469 | + // destination. |
| 470 | + if (llvm::hasSingleElement(*falseDest) && |
| 471 | + isa<ub::UnreachableOp>(falseDest->getTerminator())) { |
| 472 | + rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, |
| 473 | + condbr.getTrueOperands()); |
| 474 | + return success(); |
| 475 | + } |
| 476 | + |
| 477 | + return failure(); |
| 478 | + } |
| 479 | +}; |
448 | 480 | } // namespace |
449 | 481 |
|
450 | 482 | void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, |
451 | 483 | MLIRContext *context) { |
452 | 484 | results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch, |
453 | 485 | SimplifyCondBranchIdenticalSuccessors, |
454 | 486 | SimplifyCondBranchFromCondBranchOnSameCondition, |
455 | | - CondBranchTruthPropagation>(context); |
| 487 | + CondBranchTruthPropagation, DropUnreachableCondBranch>(context); |
456 | 488 | } |
457 | 489 |
|
458 | 490 | SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) { |
|
0 commit comments