4848#include  " mlir/Transforms/DialectConversion.h" 
4949#include  " clang/CIR/Dialect/IR/CIRDialect.h" 
5050#include  " clang/CIR/Dialect/IR/CIRTypes.h" 
51+ #include  " clang/CIR/Interfaces/CIRLoopOpInterface.h" 
5152#include  " clang/CIR/LowerToLLVM.h" 
5253#include  " clang/CIR/LowerToMLIR.h" 
5354#include  " clang/CIR/LoweringHelpers.h" 
5455#include  " clang/CIR/Passes.h" 
5556#include  " llvm/ADT/STLExtras.h" 
56- #include  " llvm/Support/ErrorHandling.h" 
57- #include  " clang/CIR/Interfaces/CIRLoopOpInterface.h" 
58- #include  " clang/CIR/LowerToLLVM.h" 
59- #include  " clang/CIR/Passes.h" 
6057#include  " llvm/ADT/Sequence.h" 
6158#include  " llvm/ADT/SmallVector.h" 
6259#include  " llvm/ADT/TypeSwitch.h" 
6360#include  " llvm/IR/Value.h" 
61+ #include  " llvm/Support/ErrorHandling.h" 
6462#include  " llvm/Support/TimeProfiler.h" 
6563
6664using  namespace  cir ; 
@@ -946,8 +944,8 @@ class CIRScopeOpLowering : public mlir::OpConversionPattern<cir::ScopeOp> {
946944    } else  {
947945      //  For scopes with results, use scf.execute_region
948946      SmallVector<mlir::Type> types;
949-       if  (mlir::failed (
950-               getTypeConverter ()-> convertTypes ( scopeOp->getResultTypes (), types)))
947+       if  (mlir::failed (getTypeConverter ()-> convertTypes ( 
948+               scopeOp->getResultTypes (), types)))
951949        return  mlir::failure ();
952950      auto  exec =
953951          rewriter.create <mlir::scf::ExecuteRegionOp>(scopeOp.getLoc (), types);
@@ -1515,28 +1513,117 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern<cir::TrapOp> {
15151513  }
15161514};
15171515
1516+ class  CIRSwitchOpLowering  : public  mlir ::OpConversionPattern<cir::SwitchOp> {
1517+ public: 
1518+   using  OpConversionPattern<cir::SwitchOp>::OpConversionPattern;
1519+ 
1520+   mlir::LogicalResult
1521+   matchAndRewrite (cir::SwitchOp op, OpAdaptor adaptor,
1522+                   mlir::ConversionPatternRewriter &rewriter) const  override  {
1523+     rewriter.setInsertionPointAfter (op);
1524+     llvm::SmallVector<CaseOp> cases;
1525+     if  (!op.isSimpleForm (cases))
1526+       llvm_unreachable (" NYI" 
1527+ 
1528+     llvm::SmallVector<int64_t > caseValues;
1529+     //  Maps the index of a CaseOp in `cases`, to the index in `caseValues`.
1530+     //  This is necessary because some CaseOp might carry 0 or multiple values.
1531+     llvm::DenseMap<size_t , unsigned > indexMap;
1532+     caseValues.reserve (cases.size ());
1533+     for  (auto  [i, caseOp] : llvm::enumerate (cases)) {
1534+       switch  (caseOp.getKind ()) {
1535+       case  CaseOpKind::Equal: {
1536+         auto  valueAttr = caseOp.getValue ()[0 ];
1537+         auto  value = cast<cir::IntAttr>(valueAttr);
1538+         indexMap[i] = caseValues.size ();
1539+         caseValues.push_back (value.getUInt ());
1540+         break ;
1541+       }
1542+       case  CaseOpKind::Default:
1543+         break ;
1544+       case  CaseOpKind::Range:
1545+       case  CaseOpKind::Anyof:
1546+         llvm_unreachable (" NYI" 
1547+       }
1548+     }
1549+ 
1550+     auto  operand = adaptor.getOperands ()[0 ];
1551+     //  `scf.index_switch` expects an index of type `index`.
1552+     auto  indexType = mlir::IndexType::get (getContext ());
1553+     auto  indexCast = rewriter.create <mlir::arith::IndexCastOp>(
1554+         op.getLoc (), indexType, operand);
1555+     auto  indexSwitch = rewriter.create <mlir::scf::IndexSwitchOp>(
1556+         op.getLoc (), mlir::TypeRange{}, indexCast, caseValues, cases.size ());
1557+ 
1558+     bool  metDefault = false ;
1559+     for  (auto  [i, caseOp] : llvm::enumerate (cases)) {
1560+       auto  ®ion = caseOp.getRegion ();
1561+       switch  (caseOp.getKind ()) {
1562+       case  CaseOpKind::Equal: {
1563+         auto  &caseRegion = indexSwitch.getCaseRegions ()[indexMap[i]];
1564+         rewriter.inlineRegionBefore (region, caseRegion, caseRegion.end ());
1565+         break ;
1566+       }
1567+       case  CaseOpKind::Default: {
1568+         auto  &defaultRegion = indexSwitch.getDefaultRegion ();
1569+         rewriter.inlineRegionBefore (region, defaultRegion, defaultRegion.end ());
1570+         metDefault = true ;
1571+         break ;
1572+       }
1573+       case  CaseOpKind::Range:
1574+       case  CaseOpKind::Anyof:
1575+         llvm_unreachable (" NYI" 
1576+       }
1577+     }
1578+ 
1579+     //  `scf.index_switch` expects its default region to contain exactly one
1580+     //  block. If we don't have a default region in `cir.switch`, we need to
1581+     //  supply it here.
1582+     if  (!metDefault) {
1583+       auto  &defaultRegion = indexSwitch.getDefaultRegion ();
1584+       mlir::Block *block =
1585+           rewriter.createBlock (&defaultRegion, defaultRegion.end ());
1586+       rewriter.setInsertionPointToEnd (block);
1587+       rewriter.create <mlir::scf::YieldOp>(op.getLoc ());
1588+     }
1589+ 
1590+     //  The final `cir.break` should be replaced to `scf.yield`.
1591+     //  After MLIRLoweringPrepare pass, every case must end with a `cir.break`.
1592+     for  (auto  ®ion : indexSwitch.getCaseRegions ()) {
1593+       auto  &lastBlock = region.back ();
1594+       auto  &lastOp = lastBlock.back ();
1595+       assert (isa<BreakOp>(lastOp));
1596+       rewriter.setInsertionPointAfter (&lastOp);
1597+       rewriter.replaceOpWithNewOp <mlir::scf::YieldOp>(&lastOp);
1598+     }
1599+ 
1600+     rewriter.replaceOp (op, indexSwitch);
1601+ 
1602+     return  mlir::success ();
1603+   }
1604+ };
1605+ 
15181606void  populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
15191607                                         mlir::TypeConverter &converter) {
15201608  patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
15211609
1522-   patterns
1523-       .add <CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
1524-            CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
1525-            CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
1526-            CIRFuncOpLowering, CIRBrCondOpLowering,
1527-            CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
1528-            CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1529-            CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
1530-            CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
1531-            CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
1532-            CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1533-            CIRRoundOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
1534-            CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
1535-            CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
1536-            CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1537-            CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
1538-            CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
1539-            CIRTrapOpLowering>(converter, patterns.getContext ());
1610+   patterns.add <
1611+       CIRSwitchOpLowering, CIRATanOpLowering, CIRCmpOpLowering,
1612+       CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
1613+       CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
1614+       CIRAllocaOpLowering, CIRFuncOpLowering, CIRBrCondOpLowering,
1615+       CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
1616+       CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1617+       CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
1618+       CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1619+       CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
1620+       CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering, CIRSinOpLowering,
1621+       CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
1622+       CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
1623+       CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering,
1624+       CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering,
1625+       CIRACosOpLowering, CIRASinOpLowering, CIRUnreachableOpLowering,
1626+       CIRTanOpLowering, CIRTrapOpLowering>(converter, patterns.getContext ());
15401627}
15411628
15421629static  mlir::TypeConverter prepareTypeConverter () {
@@ -1610,7 +1697,7 @@ void ConvertCIRToMLIRPass::runOnOperation() {
16101697  mlir::ModuleOp theModule = getOperation ();
16111698
16121699  auto  converter = prepareTypeConverter ();
1613-    
1700+ 
16141701  mlir::RewritePatternSet patterns (&getContext ());
16151702
16161703  populateCIRLoopToSCFConversionPatterns (patterns, converter);
@@ -1628,10 +1715,11 @@ void ConvertCIRToMLIRPass::runOnOperation() {
16281715  //  cir dialect, for example the `cir.continue`. If we marked cir as illegal
16291716  //  here, then MLIR would think any remaining `cir.continue` indicates a
16301717  //  failure, which is not what we want.
1631-   
1632-   patterns.add <CIRCastOpLowering, CIRIfOpLowering, CIRScopeOpLowering, CIRYieldOpLowering>(converter, context);
16331718
1634-   if  (mlir::failed (mlir::applyPartialConversion (theModule, target, 
1719+   patterns.add <CIRCastOpLowering, CIRIfOpLowering, CIRScopeOpLowering,
1720+                CIRYieldOpLowering>(converter, context);
1721+ 
1722+   if  (mlir::failed (mlir::applyPartialConversion (theModule, target,
16351723                                                std::move (patterns)))) {
16361724    signalPassFailure ();
16371725  }
@@ -1646,6 +1734,7 @@ mlir::ModuleOp lowerFromCIRToMLIRToLLVMDialect(mlir::ModuleOp theModule,
16461734
16471735  mlir::PassManager pm (mlirCtx);
16481736
1737+   pm.addPass (createMLIRLoweringPreparePass ());
16491738  pm.addPass (createConvertCIRToMLIRPass ());
16501739  pm.addPass (createConvertMLIRToLLVMPass ());
16511740
@@ -1712,6 +1801,8 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule,
17121801  llvm::TimeTraceScope scope (" Lower CIR To MLIR" 
17131802
17141803  mlir::PassManager pm (mlirCtx);
1804+ 
1805+   pm.addPass (createMLIRLoweringPreparePass ());
17151806  pm.addPass (createConvertCIRToMLIRPass ());
17161807
17171808  auto  result = !mlir::failed (pm.run (theModule));
0 commit comments