Skip to content

Conversation

@AmrDeveloper
Copy link
Member

Upstream TryCall Op as a prerequisite for Try Catch work

Issue #154992

@llvmbot llvmbot added clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project labels Oct 27, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 27, 2025

@llvm/pr-subscribers-clang

Author: Amr Hesham (AmrDeveloper)

Changes

Upstream TryCall Op as a prerequisite for Try Catch work

Issue #154992


Full diff: https://github.com/llvm/llvm-project/pull/165303.diff

5 Files Affected:

  • (modified) clang/include/clang/CIR/Dialect/IR/CIRDialect.td (+1)
  • (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+93-1)
  • (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+191-6)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+19-5)
  • (added) clang/test/CIR/IR/try-call.cir (+31)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index e91537186df59..34df9af7fc06d 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -44,6 +44,7 @@ def CIR_Dialect : Dialect {
     static llvm::StringRef getModuleLevelAsmAttrName() { return "cir.module_asm"; }
     static llvm::StringRef getGlobalCtorsAttrName() { return "cir.global_ctors"; }
     static llvm::StringRef getGlobalDtorsAttrName() { return "cir.global_dtors"; }
+    static llvm::StringRef getOperandSegmentSizesAttrName() { return "operandSegmentSizes"; }
 
     void registerAttributes();
     void registerTypes();
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 2b361ed0982c6..8f3e25b3c9737 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2580,7 +2580,7 @@ def CIR_FuncOp : CIR_Op<"func", [
 }
 
 //===----------------------------------------------------------------------===//
-// CallOp
+// CallOp and TryCallOp
 //===----------------------------------------------------------------------===//
 
 def CIR_SideEffect : CIR_I32EnumAttr<
@@ -2707,6 +2707,98 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
   ];
 }
 
+def CIR_TryCallOp : CIR_CallOpBase<"try_call",[
+  DeclareOpInterfaceMethods<BranchOpInterface>,
+  Terminator, AttrSizedOperandSegments
+]> {
+  let summary = "try_call operation";
+
+  let description = [{
+    Mostly similar to cir.call but requires two destination
+    branches, one for handling exceptions in case its thrown and
+    the other one to follow on regular control-flow.
+
+    Example:
+
+    ```mlir
+    // Direct call
+    %result = cir.try_call @division(%a, %b) ^continue, ^landing_pad 
+      : (f32, f32) -> f32
+    ```
+  }];
+
+  let arguments = !con((ins
+    Variadic<CIR_AnyType>:$contOperands,
+    Variadic<CIR_AnyType>:$landingPadOperands
+  ), commonArgs);
+
+  let results = (outs Optional<CIR_AnyType>:$result);
+  let successors = (successor AnySuccessor:$cont, AnySuccessor:$landing_pad);
+
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+    OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::Type":$resType,
+               "mlir::Block *":$cont, "mlir::Block *":$landing_pad,
+               CArg<"mlir::ValueRange", "{}">:$operands,
+               CArg<"mlir::ValueRange", "{}">:$contOperands,
+               CArg<"mlir::ValueRange", "{}">:$landingPadOperands,
+               CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
+      $_state.addOperands(operands);
+      if (callee)
+        $_state.addAttribute("callee", callee);
+      if (resType && !isa<VoidType>(resType))
+        $_state.addTypes(resType);
+
+      $_state.addAttribute("side_effect",
+        SideEffectAttr::get($_builder.getContext(), sideEffect));
+
+      // Handle branches
+      $_state.addOperands(contOperands);
+      $_state.addOperands(landingPadOperands);
+      // The TryCall ODS layout is: cont, landing_pad, operands.
+      llvm::copy(::llvm::ArrayRef<int32_t>({
+        static_cast<int32_t>(contOperands.size()),
+        static_cast<int32_t>(landingPadOperands.size()),
+        static_cast<int32_t>(operands.size())
+        }),
+        odsState.getOrAddProperties<Properties>().operandSegmentSizes.begin());
+      $_state.addSuccessors(cont);
+      $_state.addSuccessors(landing_pad);
+    }]>,
+    OpBuilder<(ins "mlir::Value":$ind_target,
+               "FuncType":$fn_type,
+               "mlir::Block *":$cont, "mlir::Block *":$landing_pad,
+               CArg<"mlir::ValueRange", "{}">:$operands,
+               CArg<"mlir::ValueRange", "{}">:$contOperands,
+               CArg<"mlir::ValueRange", "{}">:$landingPadOperands, 
+               CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
+      ::llvm::SmallVector<mlir::Value, 4> finalCallOperands({ind_target});
+      finalCallOperands.append(operands.begin(), operands.end());
+      $_state.addOperands(finalCallOperands);
+
+      if (!fn_type.hasVoidReturn())
+        $_state.addTypes(fn_type.getReturnType());
+
+      $_state.addAttribute("side_effect",
+        SideEffectAttr::get($_builder.getContext(), sideEffect));
+
+      // Handle branches
+      $_state.addOperands(contOperands);
+      $_state.addOperands(landingPadOperands);
+      // The TryCall ODS layout is: cont, landing_pad, operands.
+      llvm::copy(::llvm::ArrayRef<int32_t>({
+        static_cast<int32_t>(contOperands.size()),
+        static_cast<int32_t>(landingPadOperands.size()),
+        static_cast<int32_t>(finalCallOperands.size())
+        }),
+        odsState.getOrAddProperties<Properties>().operandSegmentSizes.begin());
+      $_state.addSuccessors(cont);
+      $_state.addSuccessors(landing_pad);
+    }]>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // CopyOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 2d2ef422bfaef..11074af3ef127 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -701,13 +701,78 @@ unsigned cir::CallOp::getNumArgOperands() {
   return this->getOperation()->getNumOperands();
 }
 
+static mlir::ParseResult
+parseTryCallBranches(mlir::OpAsmParser &parser, mlir::OperationState &result,
+                     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
+                         &continueOperands,
+                     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
+                         &landingPadOperands,
+                     llvm::SmallVectorImpl<mlir::Type> &continueTypes,
+                     llvm::SmallVectorImpl<mlir::Type> &landingPadTypes,
+                     llvm::SMLoc &continueOperandsLoc,
+                     llvm::SMLoc &landingPadOperandsLoc) {
+  mlir::Block *continueSuccessor = nullptr;
+  mlir::Block *landingPadSuccessor = nullptr;
+
+  if (parser.parseSuccessor(continueSuccessor))
+    return mlir::failure();
+
+  if (mlir::succeeded(parser.parseOptionalLParen())) {
+    continueOperandsLoc = parser.getCurrentLocation();
+    if (parser.parseOperandList(continueOperands))
+      return mlir::failure();
+    if (parser.parseColon())
+      return mlir::failure();
+
+    if (parser.parseTypeList(continueTypes))
+      return mlir::failure();
+    if (parser.parseRParen())
+      return mlir::failure();
+  }
+
+  if (parser.parseComma())
+    return mlir::failure();
+
+  if (parser.parseSuccessor(landingPadSuccessor))
+    return mlir::failure();
+
+  if (mlir::succeeded(parser.parseOptionalLParen())) {
+    landingPadOperandsLoc = parser.getCurrentLocation();
+    if (parser.parseOperandList(landingPadOperands))
+      return mlir::failure();
+    if (parser.parseColon())
+      return mlir::failure();
+
+    if (parser.parseTypeList(landingPadTypes))
+      return mlir::failure();
+    if (parser.parseRParen())
+      return mlir::failure();
+  }
+
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return mlir::failure();
+
+  result.addSuccessors(continueSuccessor);
+  result.addSuccessors(landingPadSuccessor);
+  return mlir::success();
+}
+
 static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
-                                         mlir::OperationState &result) {
+                                         mlir::OperationState &result,
+                                         bool hasDestinationBlocks = false) {
   llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops;
   llvm::SMLoc opsLoc;
   mlir::FlatSymbolRefAttr calleeAttr;
   llvm::ArrayRef<mlir::Type> allResultTypes;
 
+  // TryCall control flow related
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> continueOperands;
+  llvm::SMLoc continueOperandsLoc;
+  llvm::SmallVector<mlir::Type, 1> continueTypes;
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> landingPadOperands;
+  llvm::SMLoc landingPadOperandsLoc;
+  llvm::SmallVector<mlir::Type, 1> landingPadTypes;
+
   // If we cannot parse a string callee, it means this is an indirect call.
   if (!parser
            .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
@@ -729,6 +794,14 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
   if (parser.parseRParen())
     return mlir::failure();
 
+  if (hasDestinationBlocks &&
+      parseTryCallBranches(parser, result, continueOperands, landingPadOperands,
+                           continueTypes, landingPadTypes, continueOperandsLoc,
+                           landingPadOperandsLoc)
+          .failed()) {
+    return ::mlir::failure();
+  }
+
   if (parser.parseOptionalKeyword("nothrow").succeeded())
     result.addAttribute(CIRDialect::getNoThrowAttrName(),
                         mlir::UnitAttr::get(parser.getContext()));
@@ -761,6 +834,24 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
   if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands))
     return mlir::failure();
 
+  if (hasDestinationBlocks) {
+    // The TryCall ODS layout is: cont, landing_pad, operands.
+    llvm::copy(::llvm::ArrayRef<int32_t>(
+                   {static_cast<int32_t>(continueOperands.size()),
+                    static_cast<int32_t>(landingPadOperands.size()),
+                    static_cast<int32_t>(ops.size())}),
+               result.getOrAddProperties<cir::TryCallOp::Properties>()
+                   .operandSegmentSizes.begin());
+
+    if (parser.resolveOperands(continueOperands, continueTypes,
+                               continueOperandsLoc, result.operands))
+      return ::mlir::failure();
+
+    if (parser.resolveOperands(landingPadOperands, landingPadTypes,
+                               landingPadOperandsLoc, result.operands))
+      return ::mlir::failure();
+  }
+
   return mlir::success();
 }
 
@@ -768,7 +859,9 @@ static void printCallCommon(mlir::Operation *op,
                             mlir::FlatSymbolRefAttr calleeSym,
                             mlir::Value indirectCallee,
                             mlir::OpAsmPrinter &printer, bool isNothrow,
-                            cir::SideEffect sideEffect) {
+                            cir::SideEffect sideEffect,
+                            mlir::Block *cont = nullptr,
+                            mlir::Block *landingPad = nullptr) {
   printer << ' ';
 
   auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
@@ -782,8 +875,35 @@ static void printCallCommon(mlir::Operation *op,
     assert(indirectCallee);
     printer << indirectCallee;
   }
+
   printer << "(" << ops << ")";
 
+  if (cont) {
+    assert(landingPad && "expected two successors");
+    auto tryCall = dyn_cast<cir::TryCallOp>(op);
+    assert(tryCall && "regular calls do not branch");
+    printer << ' ' << tryCall.getCont();
+    if (!tryCall.getContOperands().empty()) {
+      printer << "(";
+      printer << tryCall.getContOperands();
+      printer << ' ' << ":";
+      printer << ' ';
+      printer << tryCall.getContOperands().getTypes();
+      printer << ")";
+    }
+    printer << ",";
+    printer << ' ';
+    printer << tryCall.getLandingPad();
+    if (!tryCall.getLandingPadOperands().empty()) {
+      printer << "(";
+      printer << tryCall.getLandingPadOperands();
+      printer << ' ' << ":";
+      printer << ' ';
+      printer << tryCall.getLandingPadOperands().getTypes();
+      printer << ")";
+    }
+  }
+
   if (isNothrow)
     printer << " nothrow";
 
@@ -793,10 +913,11 @@ static void printCallCommon(mlir::Operation *op,
     printer << ")";
   }
 
-  printer.printOptionalAttrDict(op->getAttrs(),
-                                {CIRDialect::getCalleeAttrName(),
-                                 CIRDialect::getNoThrowAttrName(),
-                                 CIRDialect::getSideEffectAttrName()});
+  llvm::SmallVector<::llvm::StringRef, 4> elidedAttrs = {
+      CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(),
+      CIRDialect::getSideEffectAttrName(),
+      CIRDialect::getOperandSegmentSizesAttrName()};
+  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 
   printer << " : ";
   printer.printFunctionalType(op->getOperands().getTypes(),
@@ -878,6 +999,70 @@ cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return verifyCallCommInSymbolUses(*this, symbolTable);
 }
 
+//===----------------------------------------------------------------------===//
+// TryCallOp
+//===----------------------------------------------------------------------===//
+
+mlir::OperandRange cir::TryCallOp::getArgOperands() {
+  if (isIndirect())
+    return getArgs().drop_front(1);
+  return getArgs();
+}
+
+mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
+  mlir::MutableOperandRange args = getArgsMutable();
+  if (isIndirect())
+    return args.slice(1, args.size() - 1);
+  return args;
+}
+
+mlir::Value cir::TryCallOp::getIndirectCall() {
+  assert(isIndirect());
+  return getOperand(0);
+}
+
+/// Return the operand at index 'i'.
+Value cir::TryCallOp::getArgOperand(unsigned i) {
+  if (isIndirect())
+    ++i;
+  return getOperand(i);
+}
+
+/// Return the number of operands.
+unsigned cir::TryCallOp::getNumArgOperands() {
+  if (isIndirect())
+    return this->getOperation()->getNumOperands() - 1;
+  return this->getOperation()->getNumOperands();
+}
+
+LogicalResult
+cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  return verifyCallCommInSymbolUses(*this, symbolTable);
+}
+
+mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
+                                        mlir::OperationState &result) {
+  return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
+}
+
+void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
+  mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
+  cir::SideEffect sideEffect = getSideEffect();
+  printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
+                  sideEffect, getCont(), getLandingPad());
+}
+
+mlir::SuccessorOperands cir::TryCallOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  if (index == 0)
+    return SuccessorOperands(getContOperandsMutable());
+  if (index == 1)
+    return SuccessorOperands(getLandingPadOperandsMutable());
+
+  // index == 2
+  return SuccessorOperands(getArgOperandsMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // ReturnOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 5a6193fa8d840..12f3db01c77d8 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1385,7 +1385,9 @@ static mlir::LogicalResult
 rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
                     mlir::ConversionPatternRewriter &rewriter,
                     const mlir::TypeConverter *converter,
-                    mlir::FlatSymbolRefAttr calleeAttr) {
+                    mlir::FlatSymbolRefAttr calleeAttr,
+                    mlir::Block *continueBlock = nullptr,
+                    mlir::Block *landingPadBlock = nullptr) {
   llvm::SmallVector<mlir::Type, 8> llvmResults;
   mlir::ValueTypeRange<mlir::ResultRange> cirResults = op->getResultTypes();
   auto call = cast<cir::CIRCallOpInterface>(op);
@@ -1414,7 +1416,7 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
       llvmFnTy = converter->convertType<mlir::LLVM::LLVMFunctionType>(
           fn.getFunctionType());
       assert(llvmFnTy && "Failed to convert function type");
-    } else if (auto alias = mlir::cast<mlir::LLVM::AliasOp>(callee)) {
+    } else if (auto alias = mlir::dyn_cast<mlir::LLVM::AliasOp>(callee)) {
       // If the callee was an alias. In that case,
       // we need to prepend the address of the alias to the operands. The
       // way aliases work in the LLVM dialect is a little counter-intuitive.
@@ -1452,17 +1454,21 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
         converter->convertType(calleeFuncTy));
   }
 
-  assert(!cir::MissingFeatures::opCallLandingPad());
-  assert(!cir::MissingFeatures::opCallContinueBlock());
   assert(!cir::MissingFeatures::opCallCallConv());
 
+  if (landingPadBlock) {
+    rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
+        op, llvmFnTy, calleeAttr, callOperands, continueBlock,
+        mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{});
+    return mlir::success();
+  }
+
   auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
       op, llvmFnTy, calleeAttr, callOperands);
   if (memoryEffects)
     newOp.setMemoryEffectsAttr(memoryEffects);
   newOp.setNoUnwind(noUnwind);
   newOp.setWillReturn(willReturn);
-
   return mlir::success();
 }
 
@@ -1473,6 +1479,14 @@ mlir::LogicalResult CIRToLLVMCallOpLowering::matchAndRewrite(
                              getTypeConverter(), op.getCalleeAttr());
 }
 
+mlir::LogicalResult CIRToLLVMTryCallOpLowering::matchAndRewrite(
+    cir::TryCallOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  return rewriteCallOrInvoke(op.getOperation(), adaptor.getOperands(), rewriter,
+                             getTypeConverter(), op.getCalleeAttr(),
+                             op.getCont(), op.getLandingPad());
+}
+
 mlir::LogicalResult CIRToLLVMReturnAddrOpLowering::matchAndRewrite(
     cir::ReturnAddrOp op, OpAdaptor adaptor,
     mlir::ConversionPatternRewriter &rewriter) const {
diff --git a/clang/test/CIR/IR/try-call.cir b/clang/test/CIR/IR/try-call.cir
new file mode 100644
index 0000000000000..6c23d3add15c8
--- /dev/null
+++ b/clang/test/CIR/IR/try-call.cir
@@ -0,0 +1,31 @@
+// RUN: cir-opt %s --verify-roundtrip | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+
+cir.func private @division(%a: !s32i, %b: !s32i) -> !s32i
+
+cir.func @flatten_structure_with_try_call_op() {
+   %a = cir.const #cir.int<1> : !s32i
+   %b = cir.const #cir.int<2> : !s32i
+   %3 = cir.try_call @division(%a, %b) ^continue, ^landing_pad : (!s32i, !s32i) -> !s32i
+ ^continue:
+   cir.br ^landing_pad
+ ^landing_pad:
+   cir.return
+}
+
+// CHECK: cir.func private @division(!s32i, !s32i) -> !s32i
+
+// CHECK: cir.func @flatten_structure_with_try_call_op() {
+// CHECK-NEXT:   %[[CONST_0:.*]] = cir.const #cir.int<1> : !s32i
+// CHECK-NEXT:   %[[CONST_1:.*]] = cir.const #cir.int<2> : !s32i
+// CHECK-NEXT:   %[[CALL:.*]] = cir.try_call @division(%0, %1) ^[[CONTINUE:.*]], ^[[LANDING_PAD:.*]] : (!s32i, !s32i) -> !s32i
+// CHECK-NEXT: ^[[CONTINUE]]:
+// CHECK-NEXT:   cir.br ^[[LANDING_PAD]]
+// CHECK-NEXT: ^[[LANDING_PAD]]:
+// CHECK-NEXT:   cir.return
+// CHECK-NEXT: }
+
+}

@llvmbot
Copy link
Member

llvmbot commented Oct 27, 2025

@llvm/pr-subscribers-clangir

Author: Amr Hesham (AmrDeveloper)

Changes

Upstream TryCall Op as a prerequisite for Try Catch work

Issue #154992


Full diff: https://github.com/llvm/llvm-project/pull/165303.diff

5 Files Affected:

  • (modified) clang/include/clang/CIR/Dialect/IR/CIRDialect.td (+1)
  • (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+93-1)
  • (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+191-6)
  • (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+19-5)
  • (added) clang/test/CIR/IR/try-call.cir (+31)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index e91537186df59..34df9af7fc06d 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -44,6 +44,7 @@ def CIR_Dialect : Dialect {
     static llvm::StringRef getModuleLevelAsmAttrName() { return "cir.module_asm"; }
     static llvm::StringRef getGlobalCtorsAttrName() { return "cir.global_ctors"; }
     static llvm::StringRef getGlobalDtorsAttrName() { return "cir.global_dtors"; }
+    static llvm::StringRef getOperandSegmentSizesAttrName() { return "operandSegmentSizes"; }
 
     void registerAttributes();
     void registerTypes();
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 2b361ed0982c6..8f3e25b3c9737 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2580,7 +2580,7 @@ def CIR_FuncOp : CIR_Op<"func", [
 }
 
 //===----------------------------------------------------------------------===//
-// CallOp
+// CallOp and TryCallOp
 //===----------------------------------------------------------------------===//
 
 def CIR_SideEffect : CIR_I32EnumAttr<
@@ -2707,6 +2707,98 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
   ];
 }
 
+def CIR_TryCallOp : CIR_CallOpBase<"try_call",[
+  DeclareOpInterfaceMethods<BranchOpInterface>,
+  Terminator, AttrSizedOperandSegments
+]> {
+  let summary = "try_call operation";
+
+  let description = [{
+    Mostly similar to cir.call but requires two destination
+    branches, one for handling exceptions in case its thrown and
+    the other one to follow on regular control-flow.
+
+    Example:
+
+    ```mlir
+    // Direct call
+    %result = cir.try_call @division(%a, %b) ^continue, ^landing_pad 
+      : (f32, f32) -> f32
+    ```
+  }];
+
+  let arguments = !con((ins
+    Variadic<CIR_AnyType>:$contOperands,
+    Variadic<CIR_AnyType>:$landingPadOperands
+  ), commonArgs);
+
+  let results = (outs Optional<CIR_AnyType>:$result);
+  let successors = (successor AnySuccessor:$cont, AnySuccessor:$landing_pad);
+
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+    OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::Type":$resType,
+               "mlir::Block *":$cont, "mlir::Block *":$landing_pad,
+               CArg<"mlir::ValueRange", "{}">:$operands,
+               CArg<"mlir::ValueRange", "{}">:$contOperands,
+               CArg<"mlir::ValueRange", "{}">:$landingPadOperands,
+               CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
+      $_state.addOperands(operands);
+      if (callee)
+        $_state.addAttribute("callee", callee);
+      if (resType && !isa<VoidType>(resType))
+        $_state.addTypes(resType);
+
+      $_state.addAttribute("side_effect",
+        SideEffectAttr::get($_builder.getContext(), sideEffect));
+
+      // Handle branches
+      $_state.addOperands(contOperands);
+      $_state.addOperands(landingPadOperands);
+      // The TryCall ODS layout is: cont, landing_pad, operands.
+      llvm::copy(::llvm::ArrayRef<int32_t>({
+        static_cast<int32_t>(contOperands.size()),
+        static_cast<int32_t>(landingPadOperands.size()),
+        static_cast<int32_t>(operands.size())
+        }),
+        odsState.getOrAddProperties<Properties>().operandSegmentSizes.begin());
+      $_state.addSuccessors(cont);
+      $_state.addSuccessors(landing_pad);
+    }]>,
+    OpBuilder<(ins "mlir::Value":$ind_target,
+               "FuncType":$fn_type,
+               "mlir::Block *":$cont, "mlir::Block *":$landing_pad,
+               CArg<"mlir::ValueRange", "{}">:$operands,
+               CArg<"mlir::ValueRange", "{}">:$contOperands,
+               CArg<"mlir::ValueRange", "{}">:$landingPadOperands, 
+               CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
+      ::llvm::SmallVector<mlir::Value, 4> finalCallOperands({ind_target});
+      finalCallOperands.append(operands.begin(), operands.end());
+      $_state.addOperands(finalCallOperands);
+
+      if (!fn_type.hasVoidReturn())
+        $_state.addTypes(fn_type.getReturnType());
+
+      $_state.addAttribute("side_effect",
+        SideEffectAttr::get($_builder.getContext(), sideEffect));
+
+      // Handle branches
+      $_state.addOperands(contOperands);
+      $_state.addOperands(landingPadOperands);
+      // The TryCall ODS layout is: cont, landing_pad, operands.
+      llvm::copy(::llvm::ArrayRef<int32_t>({
+        static_cast<int32_t>(contOperands.size()),
+        static_cast<int32_t>(landingPadOperands.size()),
+        static_cast<int32_t>(finalCallOperands.size())
+        }),
+        odsState.getOrAddProperties<Properties>().operandSegmentSizes.begin());
+      $_state.addSuccessors(cont);
+      $_state.addSuccessors(landing_pad);
+    }]>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // CopyOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 2d2ef422bfaef..11074af3ef127 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -701,13 +701,78 @@ unsigned cir::CallOp::getNumArgOperands() {
   return this->getOperation()->getNumOperands();
 }
 
+static mlir::ParseResult
+parseTryCallBranches(mlir::OpAsmParser &parser, mlir::OperationState &result,
+                     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
+                         &continueOperands,
+                     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
+                         &landingPadOperands,
+                     llvm::SmallVectorImpl<mlir::Type> &continueTypes,
+                     llvm::SmallVectorImpl<mlir::Type> &landingPadTypes,
+                     llvm::SMLoc &continueOperandsLoc,
+                     llvm::SMLoc &landingPadOperandsLoc) {
+  mlir::Block *continueSuccessor = nullptr;
+  mlir::Block *landingPadSuccessor = nullptr;
+
+  if (parser.parseSuccessor(continueSuccessor))
+    return mlir::failure();
+
+  if (mlir::succeeded(parser.parseOptionalLParen())) {
+    continueOperandsLoc = parser.getCurrentLocation();
+    if (parser.parseOperandList(continueOperands))
+      return mlir::failure();
+    if (parser.parseColon())
+      return mlir::failure();
+
+    if (parser.parseTypeList(continueTypes))
+      return mlir::failure();
+    if (parser.parseRParen())
+      return mlir::failure();
+  }
+
+  if (parser.parseComma())
+    return mlir::failure();
+
+  if (parser.parseSuccessor(landingPadSuccessor))
+    return mlir::failure();
+
+  if (mlir::succeeded(parser.parseOptionalLParen())) {
+    landingPadOperandsLoc = parser.getCurrentLocation();
+    if (parser.parseOperandList(landingPadOperands))
+      return mlir::failure();
+    if (parser.parseColon())
+      return mlir::failure();
+
+    if (parser.parseTypeList(landingPadTypes))
+      return mlir::failure();
+    if (parser.parseRParen())
+      return mlir::failure();
+  }
+
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return mlir::failure();
+
+  result.addSuccessors(continueSuccessor);
+  result.addSuccessors(landingPadSuccessor);
+  return mlir::success();
+}
+
 static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
-                                         mlir::OperationState &result) {
+                                         mlir::OperationState &result,
+                                         bool hasDestinationBlocks = false) {
   llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops;
   llvm::SMLoc opsLoc;
   mlir::FlatSymbolRefAttr calleeAttr;
   llvm::ArrayRef<mlir::Type> allResultTypes;
 
+  // TryCall control flow related
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> continueOperands;
+  llvm::SMLoc continueOperandsLoc;
+  llvm::SmallVector<mlir::Type, 1> continueTypes;
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> landingPadOperands;
+  llvm::SMLoc landingPadOperandsLoc;
+  llvm::SmallVector<mlir::Type, 1> landingPadTypes;
+
   // If we cannot parse a string callee, it means this is an indirect call.
   if (!parser
            .parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
@@ -729,6 +794,14 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
   if (parser.parseRParen())
     return mlir::failure();
 
+  if (hasDestinationBlocks &&
+      parseTryCallBranches(parser, result, continueOperands, landingPadOperands,
+                           continueTypes, landingPadTypes, continueOperandsLoc,
+                           landingPadOperandsLoc)
+          .failed()) {
+    return ::mlir::failure();
+  }
+
   if (parser.parseOptionalKeyword("nothrow").succeeded())
     result.addAttribute(CIRDialect::getNoThrowAttrName(),
                         mlir::UnitAttr::get(parser.getContext()));
@@ -761,6 +834,24 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
   if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands))
     return mlir::failure();
 
+  if (hasDestinationBlocks) {
+    // The TryCall ODS layout is: cont, landing_pad, operands.
+    llvm::copy(::llvm::ArrayRef<int32_t>(
+                   {static_cast<int32_t>(continueOperands.size()),
+                    static_cast<int32_t>(landingPadOperands.size()),
+                    static_cast<int32_t>(ops.size())}),
+               result.getOrAddProperties<cir::TryCallOp::Properties>()
+                   .operandSegmentSizes.begin());
+
+    if (parser.resolveOperands(continueOperands, continueTypes,
+                               continueOperandsLoc, result.operands))
+      return ::mlir::failure();
+
+    if (parser.resolveOperands(landingPadOperands, landingPadTypes,
+                               landingPadOperandsLoc, result.operands))
+      return ::mlir::failure();
+  }
+
   return mlir::success();
 }
 
@@ -768,7 +859,9 @@ static void printCallCommon(mlir::Operation *op,
                             mlir::FlatSymbolRefAttr calleeSym,
                             mlir::Value indirectCallee,
                             mlir::OpAsmPrinter &printer, bool isNothrow,
-                            cir::SideEffect sideEffect) {
+                            cir::SideEffect sideEffect,
+                            mlir::Block *cont = nullptr,
+                            mlir::Block *landingPad = nullptr) {
   printer << ' ';
 
   auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
@@ -782,8 +875,35 @@ static void printCallCommon(mlir::Operation *op,
     assert(indirectCallee);
     printer << indirectCallee;
   }
+
   printer << "(" << ops << ")";
 
+  if (cont) {
+    assert(landingPad && "expected two successors");
+    auto tryCall = dyn_cast<cir::TryCallOp>(op);
+    assert(tryCall && "regular calls do not branch");
+    printer << ' ' << tryCall.getCont();
+    if (!tryCall.getContOperands().empty()) {
+      printer << "(";
+      printer << tryCall.getContOperands();
+      printer << ' ' << ":";
+      printer << ' ';
+      printer << tryCall.getContOperands().getTypes();
+      printer << ")";
+    }
+    printer << ",";
+    printer << ' ';
+    printer << tryCall.getLandingPad();
+    if (!tryCall.getLandingPadOperands().empty()) {
+      printer << "(";
+      printer << tryCall.getLandingPadOperands();
+      printer << ' ' << ":";
+      printer << ' ';
+      printer << tryCall.getLandingPadOperands().getTypes();
+      printer << ")";
+    }
+  }
+
   if (isNothrow)
     printer << " nothrow";
 
@@ -793,10 +913,11 @@ static void printCallCommon(mlir::Operation *op,
     printer << ")";
   }
 
-  printer.printOptionalAttrDict(op->getAttrs(),
-                                {CIRDialect::getCalleeAttrName(),
-                                 CIRDialect::getNoThrowAttrName(),
-                                 CIRDialect::getSideEffectAttrName()});
+  llvm::SmallVector<::llvm::StringRef, 4> elidedAttrs = {
+      CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(),
+      CIRDialect::getSideEffectAttrName(),
+      CIRDialect::getOperandSegmentSizesAttrName()};
+  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 
   printer << " : ";
   printer.printFunctionalType(op->getOperands().getTypes(),
@@ -878,6 +999,70 @@ cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return verifyCallCommInSymbolUses(*this, symbolTable);
 }
 
+//===----------------------------------------------------------------------===//
+// TryCallOp
+//===----------------------------------------------------------------------===//
+
+mlir::OperandRange cir::TryCallOp::getArgOperands() {
+  if (isIndirect())
+    return getArgs().drop_front(1);
+  return getArgs();
+}
+
+mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
+  mlir::MutableOperandRange args = getArgsMutable();
+  if (isIndirect())
+    return args.slice(1, args.size() - 1);
+  return args;
+}
+
+mlir::Value cir::TryCallOp::getIndirectCall() {
+  assert(isIndirect());
+  return getOperand(0);
+}
+
+/// Return the operand at index 'i'.
+Value cir::TryCallOp::getArgOperand(unsigned i) {
+  if (isIndirect())
+    ++i;
+  return getOperand(i);
+}
+
+/// Return the number of operands.
+unsigned cir::TryCallOp::getNumArgOperands() {
+  if (isIndirect())
+    return this->getOperation()->getNumOperands() - 1;
+  return this->getOperation()->getNumOperands();
+}
+
+LogicalResult
+cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  return verifyCallCommInSymbolUses(*this, symbolTable);
+}
+
+mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
+                                        mlir::OperationState &result) {
+  return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
+}
+
+void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
+  mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
+  cir::SideEffect sideEffect = getSideEffect();
+  printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
+                  sideEffect, getCont(), getLandingPad());
+}
+
+mlir::SuccessorOperands cir::TryCallOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  if (index == 0)
+    return SuccessorOperands(getContOperandsMutable());
+  if (index == 1)
+    return SuccessorOperands(getLandingPadOperandsMutable());
+
+  // index == 2
+  return SuccessorOperands(getArgOperandsMutable());
+}
+
 //===----------------------------------------------------------------------===//
 // ReturnOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 5a6193fa8d840..12f3db01c77d8 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1385,7 +1385,9 @@ static mlir::LogicalResult
 rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
                     mlir::ConversionPatternRewriter &rewriter,
                     const mlir::TypeConverter *converter,
-                    mlir::FlatSymbolRefAttr calleeAttr) {
+                    mlir::FlatSymbolRefAttr calleeAttr,
+                    mlir::Block *continueBlock = nullptr,
+                    mlir::Block *landingPadBlock = nullptr) {
   llvm::SmallVector<mlir::Type, 8> llvmResults;
   mlir::ValueTypeRange<mlir::ResultRange> cirResults = op->getResultTypes();
   auto call = cast<cir::CIRCallOpInterface>(op);
@@ -1414,7 +1416,7 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
       llvmFnTy = converter->convertType<mlir::LLVM::LLVMFunctionType>(
           fn.getFunctionType());
       assert(llvmFnTy && "Failed to convert function type");
-    } else if (auto alias = mlir::cast<mlir::LLVM::AliasOp>(callee)) {
+    } else if (auto alias = mlir::dyn_cast<mlir::LLVM::AliasOp>(callee)) {
       // If the callee was an alias. In that case,
       // we need to prepend the address of the alias to the operands. The
       // way aliases work in the LLVM dialect is a little counter-intuitive.
@@ -1452,17 +1454,21 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
         converter->convertType(calleeFuncTy));
   }
 
-  assert(!cir::MissingFeatures::opCallLandingPad());
-  assert(!cir::MissingFeatures::opCallContinueBlock());
   assert(!cir::MissingFeatures::opCallCallConv());
 
+  if (landingPadBlock) {
+    rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
+        op, llvmFnTy, calleeAttr, callOperands, continueBlock,
+        mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{});
+    return mlir::success();
+  }
+
   auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
       op, llvmFnTy, calleeAttr, callOperands);
   if (memoryEffects)
     newOp.setMemoryEffectsAttr(memoryEffects);
   newOp.setNoUnwind(noUnwind);
   newOp.setWillReturn(willReturn);
-
   return mlir::success();
 }
 
@@ -1473,6 +1479,14 @@ mlir::LogicalResult CIRToLLVMCallOpLowering::matchAndRewrite(
                              getTypeConverter(), op.getCalleeAttr());
 }
 
+mlir::LogicalResult CIRToLLVMTryCallOpLowering::matchAndRewrite(
+    cir::TryCallOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  return rewriteCallOrInvoke(op.getOperation(), adaptor.getOperands(), rewriter,
+                             getTypeConverter(), op.getCalleeAttr(),
+                             op.getCont(), op.getLandingPad());
+}
+
 mlir::LogicalResult CIRToLLVMReturnAddrOpLowering::matchAndRewrite(
     cir::ReturnAddrOp op, OpAdaptor adaptor,
     mlir::ConversionPatternRewriter &rewriter) const {
diff --git a/clang/test/CIR/IR/try-call.cir b/clang/test/CIR/IR/try-call.cir
new file mode 100644
index 0000000000000..6c23d3add15c8
--- /dev/null
+++ b/clang/test/CIR/IR/try-call.cir
@@ -0,0 +1,31 @@
+// RUN: cir-opt %s --verify-roundtrip | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+
+cir.func private @division(%a: !s32i, %b: !s32i) -> !s32i
+
+cir.func @flatten_structure_with_try_call_op() {
+   %a = cir.const #cir.int<1> : !s32i
+   %b = cir.const #cir.int<2> : !s32i
+   %3 = cir.try_call @division(%a, %b) ^continue, ^landing_pad : (!s32i, !s32i) -> !s32i
+ ^continue:
+   cir.br ^landing_pad
+ ^landing_pad:
+   cir.return
+}
+
+// CHECK: cir.func private @division(!s32i, !s32i) -> !s32i
+
+// CHECK: cir.func @flatten_structure_with_try_call_op() {
+// CHECK-NEXT:   %[[CONST_0:.*]] = cir.const #cir.int<1> : !s32i
+// CHECK-NEXT:   %[[CONST_1:.*]] = cir.const #cir.int<2> : !s32i
+// CHECK-NEXT:   %[[CALL:.*]] = cir.try_call @division(%0, %1) ^[[CONTINUE:.*]], ^[[LANDING_PAD:.*]] : (!s32i, !s32i) -> !s32i
+// CHECK-NEXT: ^[[CONTINUE]]:
+// CHECK-NEXT:   cir.br ^[[LANDING_PAD]]
+// CHECK-NEXT: ^[[LANDING_PAD]]:
+// CHECK-NEXT:   cir.return
+// CHECK-NEXT: }
+
+}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants