@@ -528,7 +528,7 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
528528 for (auto continueOp : continues) {
529529 bool nested = false ;
530530 // When there is another loop between this WhileOp and the ContinueOp,
531- // we shouldn't change that loop instead.
531+ // we should change that loop instead.
532532 for (mlir::Operation *parent = continueOp->getParentOp ();
533533 parent != whileOp; parent = parent->getParentOp ()) {
534534 if (isa<WhileOp>(parent)) {
@@ -570,6 +570,73 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
570570 }
571571 }
572572
573+ void rewriteBreak (mlir::scf::WhileOp whileOp,
574+ mlir::ConversionPatternRewriter &rewriter) const {
575+ // Collect all BreakOp inside this while.
576+ llvm::SmallVector<cir::BreakOp> breaks;
577+ whileOp->walk ([&](mlir::Operation *op) {
578+ if (auto breakOp = dyn_cast<BreakOp>(op))
579+ breaks.push_back (breakOp);
580+ });
581+
582+ if (breaks.empty ())
583+ return ;
584+
585+ for (auto breakOp : breaks) {
586+ // When there is another loop between this WhileOp and the BreakOp,
587+ // we should change that loop instead.
588+ if (breakOp->getParentOfType <mlir::scf::WhileOp>() != whileOp)
589+ continue ;
590+
591+ // Similar to the case of ContinueOp, when there is an `IfOp`,
592+ // we need to take special care.
593+ for (mlir::Operation *parent = breakOp->getParentOp (); parent != whileOp;
594+ parent = parent->getParentOp ()) {
595+ if (auto ifOp = dyn_cast<cir::IfOp>(parent))
596+ llvm_unreachable (" NYI" );
597+ }
598+
599+ // Operations after this BreakOp has to be removed.
600+ for (mlir::Operation *runner = breakOp->getNextNode (); runner;) {
601+ mlir::Operation *next = runner->getNextNode ();
602+ runner->erase ();
603+ runner = next;
604+ }
605+
606+ // Blocks after this BreakOp also has to be removed.
607+ for (mlir::Block *block = breakOp->getBlock ()->getNextNode (); block;) {
608+ mlir::Block *next = block->getNextNode ();
609+ block->erase ();
610+ block = next;
611+ }
612+
613+ // We know this BreakOp isn't nested in any IfOp.
614+ // Therefore, the loop is executed only once.
615+ // We pull everything out of the loop.
616+
617+ auto &beforeOps = whileOp.getBeforeBody ()->getOperations ();
618+ for (mlir::Operation *op = &*beforeOps.begin (); op;) {
619+ if (isa<ConditionOp>(op))
620+ break ;
621+ auto *next = op->getNextNode ();
622+ op->moveBefore (whileOp);
623+ op = next;
624+ }
625+
626+ auto &afterOps = whileOp.getAfterBody ()->getOperations ();
627+ for (mlir::Operation *op = &*afterOps.begin (); op;) {
628+ if (isa<YieldOp>(op))
629+ break ;
630+ auto *next = op->getNextNode ();
631+ op->moveBefore (whileOp);
632+ op = next;
633+ }
634+
635+ // The loop itself should now be removed.
636+ rewriter.eraseOp (whileOp);
637+ }
638+ }
639+
573640public:
574641 using OpConversionPattern<cir::WhileOp>::OpConversionPattern;
575642
@@ -579,6 +646,7 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
579646 SCFWhileLoop loop (op, adaptor, &rewriter);
580647 auto whileOp = loop.transferToSCFWhileOp ();
581648 rewriteContinue (whileOp, rewriter);
649+ rewriteBreak (whileOp, rewriter);
582650 rewriter.eraseOp (op);
583651 return mlir::success ();
584652 }
0 commit comments