Skip to content

Commit 3de11e9

Browse files
[mlir][CF] Add ub.unreachable canonicalization (#169873)
Basic blocks with only a `ub.unreachable` terminator are unreachable. This commit adds a canonicalization pattern that folds to `cf.cond_br` to `cf.br` if one of the destination branches is unreachable.
1 parent a8cffb8 commit 3de11e9

File tree

3 files changed

+56
-1
lines changed

3 files changed

+56
-1
lines changed

mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRControlFlowDialect
1212
MLIRControlFlowInterfaces
1313
MLIRIR
1414
MLIRSideEffectInterfaces
15+
MLIRUBDialect
1516
)

mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Arith/IR/Arith.h"
1313
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
1414
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
15+
#include "mlir/Dialect/UB/IR/UBOps.h"
1516
#include "mlir/IR/AffineExpr.h"
1617
#include "mlir/IR/AffineMap.h"
1718
#include "mlir/IR/Builders.h"
@@ -445,14 +446,45 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
445446
return success(replaced);
446447
}
447448
};
449+
450+
/// If the destination block of a conditional branch contains only
451+
/// ub.unreachable, unconditionally branch to the other destination.
452+
struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> {
453+
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
454+
455+
LogicalResult matchAndRewrite(CondBranchOp condbr,
456+
PatternRewriter &rewriter) const override {
457+
// If the "true" destination is unreachable, branch to the "false"
458+
// destination.
459+
Block *trueDest = condbr.getTrueDest();
460+
Block *falseDest = condbr.getFalseDest();
461+
if (llvm::hasSingleElement(*trueDest) &&
462+
isa<ub::UnreachableOp>(trueDest->getTerminator())) {
463+
rewriter.replaceOpWithNewOp<BranchOp>(condbr, falseDest,
464+
condbr.getFalseOperands());
465+
return success();
466+
}
467+
468+
// If the "false" destination is unreachable, branch to the "true"
469+
// destination.
470+
if (llvm::hasSingleElement(*falseDest) &&
471+
isa<ub::UnreachableOp>(falseDest->getTerminator())) {
472+
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest,
473+
condbr.getTrueOperands());
474+
return success();
475+
}
476+
477+
return failure();
478+
}
479+
};
448480
} // namespace
449481

450482
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
451483
MLIRContext *context) {
452484
results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
453485
SimplifyCondBranchIdenticalSuccessors,
454486
SimplifyCondBranchFromCondBranchOnSameCondition,
455-
CondBranchTruthPropagation>(context);
487+
CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
456488
}
457489

458490
SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {

mlir/test/Dialect/ControlFlow/canonicalize.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,3 +634,25 @@ func.func @unsimplified_cycle_2(%c : i1) {
634634
^bb7:
635635
cf.br ^bb6
636636
}
637+
638+
// CHECK-LABEL: @drop_unreachable_branch_1
639+
// CHECK-NEXT: "test.foo"() : () -> ()
640+
// CHECK-NEXT: return
641+
func.func @drop_unreachable_branch_1(%c: i1) {
642+
cf.cond_br %c, ^bb1, ^bb2
643+
^bb1:
644+
"test.foo"() : () -> ()
645+
return
646+
^bb2:
647+
ub.unreachable
648+
}
649+
650+
// CHECK-LABEL: @drop_unreachable_branch_2
651+
// CHECK-NEXT: ub.unreachable
652+
func.func @drop_unreachable_branch_2(%c: i1) {
653+
cf.cond_br %c, ^bb1, ^bb2
654+
^bb1:
655+
ub.unreachable
656+
^bb2:
657+
ub.unreachable
658+
}

0 commit comments

Comments
 (0)