Add __builtin_wasm_test_function_pointer_signature #147076

Open: wants to merge 8 commits into base: main

hoodmane
Contributor

@hoodmane hoodmane commented Jul 4, 2025

This uses ref.test to check whether the function pointer's runtime type
matches its static type. If so, then calling it won't trap with "indirect
call signature mismatch". This would be very useful here:
https://github.com/python/cpython/blob/main/Python/emscripten_trampoline.c
and would allow us to fix function pointer mismatches on the WASI target
and the Emscripten target in a uniform way.

This is on top of #139642.
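
The check-then-call pattern described above can be modeled in portable C++. This is only a sketch under stated assumptions: `std::variant` stands in for a table entry's runtime type and `std::holds_alternative` stands in for `ref.test`. The names `TableEntry`, `testSignature`, and `callWithFallback` are hypothetical and are not part of this patch; real code would call the builtin on an actual function pointer instead.

```cpp
#include <cassert>
#include <variant>

// Two possible runtime signatures for an entry in the (modeled) function table.
using UnaryFn = int (*)(int);
using BinaryFn = int (*)(int, int);
using TableEntry = std::variant<UnaryFn, BinaryFn>;

// Hypothetical stand-in for ref.test: true when the entry's runtime type is
// exactly the type the caller statically expects.
template <class Fn> bool testSignature(const TableEntry &entry) {
  return std::holds_alternative<Fn>(entry);
}

// Trampoline: the caller wants to pass two arguments, but only does so when
// the signature test succeeds. Otherwise it falls back to the one-argument
// call instead of trapping.
int callWithFallback(const TableEntry &entry, int a, int b) {
  if (testSignature<BinaryFn>(entry))
    return std::get<BinaryFn>(entry)(a, b);
  return std::get<UnaryFn>(entry)(a);
}

int addTwo(int a, int b) { return a + b; }
int negateOne(int a) { return -a; }
```

Calling `callWithFallback` with an entry holding `addTwo` uses both arguments, while an entry holding `negateOne` silently drops the second one, which is roughly what a trampoline would do after a failed signature test.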

hoodmane added 3 commits May 12, 2025 19:13
@llvmbot added the clang, backend:WebAssembly, clang:frontend, clang:codegen, mc, and llvm:ir labels Jul 4, 2025
@llvmbot
Member

llvmbot commented Jul 4, 2025

@llvm/pr-subscribers-backend-webassembly
@llvm/pr-subscribers-clang-codegen

@llvm/pr-subscribers-clang

Author: Hood Chatham (hoodmane)


Patch is 22.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147076.diff

11 Files Affected:

  • (modified) clang/include/clang/Basic/BuiltinsWebAssembly.def (+6)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+6)
  • (modified) clang/include/clang/Sema/SemaWasm.h (+1)
  • (modified) clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp (+58)
  • (modified) clang/lib/Sema/SemaWasm.cpp (+49)
  • (modified) llvm/include/llvm/IR/IntrinsicsWebAssembly.td (+4)
  • (modified) llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp (+2)
  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp (+114)
  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td (+13)
  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp (+75)
  • (modified) llvm/test/MC/WebAssembly/reference-types.s (+15)
diff --git a/clang/include/clang/Basic/BuiltinsWebAssembly.def b/clang/include/clang/Basic/BuiltinsWebAssembly.def
index ab480369b3820..85c6ddc125169 100644
--- a/clang/include/clang/Basic/BuiltinsWebAssembly.def
+++ b/clang/include/clang/Basic/BuiltinsWebAssembly.def
@@ -198,6 +198,12 @@ TARGET_BUILTIN(__builtin_wasm_ref_null_extern, "i", "nct", "reference-types")
 // return type.
 TARGET_BUILTIN(__builtin_wasm_ref_null_func, "i", "nct", "reference-types")
 
+// Check if the runtime type of a function pointer matches its static type.
+// Used to avoid "function signature mismatch" traps. Takes a function pointer,
+// uses table.get to look up the pointer in __indirect_function_table and then
+// ref.test to test the type.
+TARGET_BUILTIN(__builtin_wasm_test_function_pointer_signature, "i.", "nct", "reference-types")
+
 // Table builtins
 TARGET_BUILTIN(__builtin_wasm_table_set,  "viii", "t", "reference-types")
 TARGET_BUILTIN(__builtin_wasm_table_get,  "iii", "t", "reference-types")
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index e5a7cdc14a737..2ec556ee64fe5 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -7460,6 +7460,8 @@ def err_typecheck_illegal_increment_decrement : Error<
   "cannot %select{decrement|increment}1 value of type %0">;
 def err_typecheck_expect_int : Error<
   "used type %0 where integer is required">;
+def err_typecheck_expect_function_pointer
+    : Error<"used type %0 where function pointer is required">;
 def err_typecheck_expect_hlsl_resource : Error<
   "used type %0 where __hlsl_resource_t is required">;
 def err_typecheck_arithmetic_incomplete_or_sizeless_type : Error<
@@ -12995,6 +12997,10 @@ def err_wasm_builtin_arg_must_match_table_element_type : Error <
   "%ordinal0 argument must match the element type of the WebAssembly table in the %ordinal1 argument">;
 def err_wasm_builtin_arg_must_be_integer_type : Error <
   "%ordinal0 argument must be an integer">;
+def err_wasm_builtin_test_fp_sig_cannot_include_reference_type
+    : Error<"__builtin_wasm_test_function_pointer_signature not supported for "
+            "function pointers with reference types in their "
+            "%select{return|parameter}0 type">;
 
 // OpenACC diagnostics.
 def warn_acc_routine_unimplemented
diff --git a/clang/include/clang/Sema/SemaWasm.h b/clang/include/clang/Sema/SemaWasm.h
index 8841fdff23035..f97b72ff58579 100644
--- a/clang/include/clang/Sema/SemaWasm.h
+++ b/clang/include/clang/Sema/SemaWasm.h
@@ -36,6 +36,7 @@ class SemaWasm : public SemaBase {
   bool BuiltinWasmTableGrow(CallExpr *TheCall);
   bool BuiltinWasmTableFill(CallExpr *TheCall);
   bool BuiltinWasmTableCopy(CallExpr *TheCall);
+  bool BuiltinWasmTestFunctionPointerSignature(CallExpr *TheCall);
 
   WebAssemblyImportNameAttr *
   mergeImportNameAttr(Decl *D, const WebAssemblyImportNameAttr &AL);
diff --git a/clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp b/clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp
index 698f43215a1be..18ac488d8628c 100644
--- a/clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp
@@ -12,7 +12,10 @@
 
 #include "CGBuiltin.h"
 #include "clang/Basic/TargetBuiltins.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/IntrinsicsWebAssembly.h"
+#include "llvm/Support/ErrorHandling.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -213,6 +216,61 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
     Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_ref_null_func);
     return Builder.CreateCall(Callee);
   }
+  case WebAssembly::BI__builtin_wasm_test_function_pointer_signature: {
+    Value *FuncRef = EmitScalarExpr(E->getArg(0));
+
+    // Get the function type from the argument's static type
+    QualType ArgType = E->getArg(0)->getType();
+    const PointerType *PtrTy = ArgType->getAs<PointerType>();
+    assert(PtrTy && "Sema should have ensured this is a function pointer");
+
+    const FunctionType *FuncTy = PtrTy->getPointeeType()->getAs<FunctionType>();
+    assert(FuncTy && "Sema should have ensured this is a function pointer");
+
+    // In the llvm IR, we won't have access anymore to the type of the function
+    // pointer so we need to insert this type information somehow. We gave the
+    // @llvm.wasm.ref.test.func varargs and here we add an extra 0 argument of
+    // the type corresponding to the type of each argument of the function
+    // signature. When we lower from the IR we'll use the types of these
+    // arguments to determine the signature we want to test for.
+
+    // Make a type index constant with 0. This gets replaced by the actual type
+    // in WebAssemblyMCInstLower.cpp.
+    llvm::FunctionType *LLVMFuncTy =
+        cast<llvm::FunctionType>(ConvertType(QualType(FuncTy, 0)));
+
+    unsigned NParams = LLVMFuncTy->getNumParams();
+    std::vector<Value *> Args;
+    // One slot for the FuncRef, one for the return type, one per parameter.
+    Args.reserve(NParams + 2);
+    // The only real argument is the FuncRef
+    Args.push_back(FuncRef);
+
+    // Add the type information
+    auto addType = [this, &Args](llvm::Type *T) {
+      if (T->isVoidTy()) {
+        // Use TokenTy as dummy for void b/c the verifier rejects a
+        // void arg with 'Instruction operands must be first-class values!'
+        // TokenTy isn't a first class value either but apparently the verifier
+        // doesn't mind it.
+        Args.push_back(
+            UndefValue::get(llvm::Type::getTokenTy(getLLVMContext())));
+      } else if (T->isFloatingPointTy()) {
+        Args.push_back(ConstantFP::get(T, 0));
+      } else if (T->isIntegerTy()) {
+        Args.push_back(ConstantInt::get(T, 0));
+      } else {
+        // TODO: Handle reference types here. For now, we reject them in Sema.
+        llvm_unreachable("Unhandled type");
+      }
+    };
+
+    addType(LLVMFuncTy->getReturnType());
+    for (unsigned i = 0; i < NParams; i++) {
+      addType(LLVMFuncTy->getParamType(i));
+    }
+    Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_ref_test_func);
+    return Builder.CreateCall(Callee, Args);
+  }
   case WebAssembly::BI__builtin_wasm_swizzle_i8x16: {
     Value *Src = EmitScalarExpr(E->getArg(0));
     Value *Indices = EmitScalarExpr(E->getArg(1));
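
The CodeGen hunk above passes zero-valued dummy arguments to the variadic intrinsic so that their types, not their values, carry the signature. A rough standalone analogue of that idea in plain C++ follows. It is a sketch only: `ValType`, `valTypeOf`, and `signatureFromDummies` are hypothetical names for illustration and do not correspond to LLVM APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>
#include <vector>

// Simplified value types, mirroring the four numeric cases the patch handles.
enum class ValType { I32, I64, F32, F64 };

// Map a C++ type to its value type at compile time.
template <class T> constexpr ValType valTypeOf() {
  if constexpr (std::is_same_v<T, std::int32_t>)
    return ValType::I32;
  else if constexpr (std::is_same_v<T, std::int64_t>)
    return ValType::I64;
  else if constexpr (std::is_same_v<T, float>)
    return ValType::F32;
  else {
    static_assert(std::is_same_v<T, double>, "unhandled type");
    return ValType::F64;
  }
}

// Only the types of the arguments matter. Their values are ignored, just as
// the zeros passed to the intrinsic are never read.
template <class... Ts> std::vector<ValType> signatureFromDummies(Ts...) {
  return {valTypeOf<Ts>()...};
}
```

For example, `signatureFromDummies(std::int32_t{0}, 0.0, 0.0f)` recovers the sequence I32, F64, F32 purely from the argument types.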
diff --git a/clang/lib/Sema/SemaWasm.cpp b/clang/lib/Sema/SemaWasm.cpp
index c0fa05bc17609..ca881550fad13 100644
--- a/clang/lib/Sema/SemaWasm.cpp
+++ b/clang/lib/Sema/SemaWasm.cpp
@@ -216,6 +216,53 @@ bool SemaWasm::BuiltinWasmTableCopy(CallExpr *TheCall) {
   return false;
 }
 
+bool SemaWasm::BuiltinWasmTestFunctionPointerSignature(CallExpr *TheCall) {
+  if (SemaRef.checkArgCount(TheCall, 1))
+    return true;
+
+  Expr *FuncPtrArg = TheCall->getArg(0);
+  QualType ArgType = FuncPtrArg->getType();
+
+  // Check that the argument is a function pointer
+  const PointerType *PtrTy = ArgType->getAs<PointerType>();
+  if (!PtrTy) {
+    return Diag(FuncPtrArg->getBeginLoc(),
+                diag::err_typecheck_expect_function_pointer)
+           << ArgType << FuncPtrArg->getSourceRange();
+  }
+
+  const FunctionProtoType *FuncTy =
+      PtrTy->getPointeeType()->getAs<FunctionProtoType>();
+  if (!FuncTy) {
+    return Diag(FuncPtrArg->getBeginLoc(),
+                diag::err_typecheck_expect_function_pointer)
+           << ArgType << FuncPtrArg->getSourceRange();
+  }
+
+  // Check that the function pointer doesn't use reference types
+  if (FuncTy->getReturnType().isWebAssemblyReferenceType()) {
+    return Diag(
+               FuncPtrArg->getBeginLoc(),
+               diag::err_wasm_builtin_test_fp_sig_cannot_include_reference_type)
+           << 0 << FuncTy->getReturnType() << FuncPtrArg->getSourceRange();
+  }
+  auto NParams = FuncTy->getNumParams();
+  for (unsigned I = 0; I < NParams; I++) {
+    if (FuncTy->getParamType(I).isWebAssemblyReferenceType()) {
+      return Diag(
+                 FuncPtrArg->getBeginLoc(),
+                 diag::
+                     err_wasm_builtin_test_fp_sig_cannot_include_reference_type)
+             << 1 << FuncPtrArg->getSourceRange();
+    }
+  }
+
+  // Set return type to int (the result of the test)
+  TheCall->setType(getASTContext().IntTy);
+
+  return false;
+}
+
 bool SemaWasm::CheckWebAssemblyBuiltinFunctionCall(const TargetInfo &TI,
                                                    unsigned BuiltinID,
                                                    CallExpr *TheCall) {
@@ -236,6 +283,8 @@ bool SemaWasm::CheckWebAssemblyBuiltinFunctionCall(const TargetInfo &TI,
     return BuiltinWasmTableFill(TheCall);
   case WebAssembly::BI__builtin_wasm_table_copy:
     return BuiltinWasmTableCopy(TheCall);
+  case WebAssembly::BI__builtin_wasm_test_function_pointer_signature:
+    return BuiltinWasmTestFunctionPointerSignature(TheCall);
   }
 
   return false;
diff --git a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
index f592ff287a0e3..fb61d8a11e5c0 100644
--- a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
+++ b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
@@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn :
   DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem],
                         "llvm.wasm.ref.is_null.exn">;
 
+def int_wasm_ref_test_func
+    : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty],
+                            [IntrNoMem], "llvm.wasm.ref.test.func">;
+
 //===----------------------------------------------------------------------===//
 // Table intrinsics
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp
index 7ee6a3d8304be..c1b3936c1dcec 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp
@@ -668,6 +668,8 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {
       if (parseFunctionTableOperand(&FunctionTable))
         return true;
       ExpectFuncType = true;
+    } else if (Name == "ref.test") {
+      ExpectFuncType = true;
     }
 
     // Returns true if the next tokens are a catch clause
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index aac68b32da13a..1024bb69e6a49 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -18,6 +18,7 @@
 #include "WebAssemblySubtarget.h"
 #include "WebAssemblyTargetMachine.h"
 #include "WebAssemblyUtilities.h"
+#include "llvm/BinaryFormat/Wasm.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -505,6 +506,51 @@ MVT WebAssemblyTargetLowering::getScalarShiftAmountTy(const DataLayout & /*DL*/,
   return Result;
 }
 
+static MachineBasicBlock *LowerRefTestFuncRef(MachineInstr &MI, DebugLoc DL,
+                                              MachineBasicBlock *BB,
+                                              const TargetInstrInfo &TII) {
+  // Lower a REF_TEST_FUNCREF_PSEUDO instruction into a REF_TEST_FUNCREF
+  // instruction by combining the signature info Imm operands that
+  // SelectionDag/InstrEmitter.cpp makes into one CImm operand. Put this into
+  // the type index placeholder for REF_TEST_FUNCREF
+  Register ResultReg = MI.getOperand(0).getReg();
+  Register FuncRefReg = MI.getOperand(1).getReg();
+
+  auto NParams = MI.getNumOperands() - 3;
+  auto Sig = APInt(NParams * 64, 0);
+
+  {
+    uint64_t V = MI.getOperand(2).getImm();
+    Sig |= int64_t(V);
+  }
+
+  for (unsigned I = 3; I < MI.getNumOperands(); I++) {
+    const MachineOperand &MO = MI.getOperand(I);
+    if (!MO.isImm()) {
+      // I'm not really sure what these are or where they come from but it seems
+      // to be okay to ignore them
+      continue;
+    }
+    uint16_t V = MO.getImm();
+    Sig <<= 64;
+    Sig |= int64_t(V);
+  }
+
+  ConstantInt *TypeInfo =
+      ConstantInt::get(BB->getParent()->getFunction().getContext(), Sig);
+
+  // Put the type info first in the placeholder for the type index, then the
+  // actual funcref arg
+  BuildMI(*BB, MI, DL, TII.get(WebAssembly::REF_TEST_FUNCREF), ResultReg)
+      .addCImm(TypeInfo)
+      .addReg(FuncRefReg);
+
+  // Remove the original instruction
+  MI.eraseFromParent();
+
+  return BB;
+}
+
 // Lower an fp-to-int conversion operator from the LLVM opcode, which has an
 // undefined result on invalid/overflow, to the WebAssembly opcode, which
 // traps on invalid/overflow.
@@ -866,6 +912,8 @@ MachineBasicBlock *WebAssemblyTargetLowering::EmitInstrWithCustomInserter(
   switch (MI.getOpcode()) {
   default:
     llvm_unreachable("Unexpected instr type to insert");
+  case WebAssembly::REF_TEST_FUNCREF_PSEUDO:
+    return LowerRefTestFuncRef(MI, DL, BB, TII);
   case WebAssembly::FP_TO_SINT_I32_F32:
     return LowerFPToInt(MI, DL, BB, TII, false, false, false,
                         WebAssembly::I32_TRUNC_S_F32);
@@ -2260,6 +2308,72 @@ SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
                            DAG.getTargetExternalSymbol(TlsBase, PtrVT)),
         0);
   }
+  case Intrinsic::wasm_ref_test_func: {
+    // First emit the TABLE_GET instruction to convert function pointer ==>
+    // funcref
+    MachineFunction &MF = DAG.getMachineFunction();
+    auto PtrVT = getPointerTy(MF.getDataLayout());
+    MCSymbol *Table =
+        WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(), Subtarget);
+    SDValue TableSym = DAG.getMCSymbol(Table, PtrVT);
+    SDValue FuncRef =
+        SDValue(DAG.getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL,
+                                   MVT::funcref, TableSym, Op.getOperand(1)),
+                0);
+
+    SmallVector<SDValue, 4> Ops;
+    Ops.push_back(FuncRef);
+
+    // We want to encode the type information into an APInt which we'll put
+    // in a CImm. However, in SelectionDag/InstrEmitter.cpp there is no code
+    // path that emits a CImm. So we need a custom inserter to put it in.
+
+    // We'll put each type argument in a separate TargetConstant which gets
+    // lowered to a MachineInstruction Imm. We combine these into a CImm in our
+    // custom inserter because it creates a problem downstream to have all these
+    // extra immediates.
+    {
+      SDValue Operand = Op.getOperand(2);
+      MVT VT = Operand.getValueType().getSimpleVT();
+      WebAssembly::BlockType V;
+      if (VT == MVT::Untyped) {
+        V = WebAssembly::BlockType::Void;
+      } else if (VT == MVT::i32) {
+        V = WebAssembly::BlockType::I32;
+      } else if (VT == MVT::i64) {
+        V = WebAssembly::BlockType::I64;
+      } else if (VT == MVT::f32) {
+        V = WebAssembly::BlockType::F32;
+      } else if (VT == MVT::f64) {
+        V = WebAssembly::BlockType::F64;
+      } else {
+        llvm_unreachable("Unhandled type!");
+      }
+      Ops.push_back(DAG.getTargetConstant((int64_t)V, DL, MVT::i64));
+    }
+
+    for (unsigned i = 3; i < Op.getNumOperands(); ++i) {
+      SDValue Operand = Op.getOperand(i);
+      MVT VT = Operand.getValueType().getSimpleVT();
+      wasm::ValType V;
+      if (VT == MVT::i32) {
+        V = wasm::ValType::I32;
+      } else if (VT == MVT::i64) {
+        V = wasm::ValType::I64;
+      } else if (VT == MVT::f32) {
+        V = wasm::ValType::F32;
+      } else if (VT == MVT::f64) {
+        V = wasm::ValType::F64;
+      } else {
+        llvm_unreachable("Unhandled type!");
+      }
+      Ops.push_back(DAG.getTargetConstant((int64_t)V, DL, MVT::i64));
+    }
+
+    return SDValue(DAG.getMachineNode(WebAssembly::REF_TEST_FUNCREF_PSEUDO, DL,
+                                      MVT::i32, Ops),
+                   0);
+  }
   }
 }
 
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td
index 2654a09387fd4..0c61f5770e748 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td
@@ -36,6 +36,19 @@ multiclass REF_I<WebAssemblyRegClass rc, ValueType vt, string ht> {
         Requires<[HasReferenceTypes]>;
 }
 
+let usesCustomInserter = 1, isPseudo = 1 in defm REF_TEST_FUNCREF_PSEUDO
+    : I<(outs I32:$res), (ins TypeIndex:$type, FUNCREF:$ref, variable_ops),
+        (outs), (ins TypeIndex:$type), [], "ref.test.pseudo\t$type, $ref",
+        "ref.test.pseudo $type", -1>;
+
+defm REF_TEST_FUNCREF :
+  I<(outs I32: $res),
+    (ins TypeIndex:$type, FUNCREF: $ref),
+    (outs),
+    (ins TypeIndex:$type),
+    [],
+    "ref.test\t$type, $ref", "ref.test $type", 0xfb14>;
+
 defm "" : REF_I<FUNCREF, funcref, "func">;
 defm "" : REF_I<EXTERNREF, externref, "extern">;
 defm "" : REF_I<EXNREF, exnref, "exn">;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
index 8c8629203bca7..770942d09e429 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
@@ -15,13 +15,17 @@
 #include "WebAssemblyMCInstLower.h"
 #include "MCTargetDesc/WebAssemblyMCExpr.h"
 #include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
+#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h"
 #include "TargetInfo/WebAssemblyTargetInfo.h"
 #include "Utils/WebAssemblyTypeUtilities.h"
 #include "WebAssemblyAsmPrinter.h"
 #include "WebAssemblyMachineFunctionInfo.h"
 #include "WebAssemblyUtilities.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/BinaryFormat/Wasm.h"
 #include "llvm/CodeGen/AsmPrinter.h"
 #include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineOperand.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/MC/MCAsmInfo.h"
 #include "llvm/MC/MCContext.h"
@@ -198,11 +202,81 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
       MCOp = MCOperand::createReg(WAReg);
       break;
     }
+    case llvm::MachineOperand::MO_CImmediate: {
+      // Lower type index placeholder for ref.test
+      // Currently this is the only way that CImmediates show up so panic if we
+      // get confused.
+      unsigned DescIndex = I - NumVariadicDefs;
+      if (DescIndex >= Desc.NumOperands) {
+        llvm_unreachable("unexpected CImmediate operand");
+      }
+      const MCOperandInfo &Info = Desc.operands()[DescIndex];
+      if (Info.OperandType != WebAssembly::OPERAND_TYPEINDEX) {
+        llvm_unreachable("unexpected CImmediate operand");
+      }
+      auto CImm = MO.getCImm()->getValue();
+      auto NumWords = CImm.getNumWords();
+      // Extract the type data we packed into the CImm in LowerRefTestFuncRef.
+      // We need to load the words from most significant to least significant
+      // order because of the way we bitshifted them in from the right.
+      // The return type needs special handling because it could be void.
+      auto ReturnType = static_cast<WebAssembly::BlockType>(
+          CImm.extractBitsAsZExtValue(64, (NumWords - 1) * 64));
+      assert(ReturnType != WebAssembly::BlockType::Invalid);
+      SmallVector<wasm::ValType, 2> Returns;
+      switch (ReturnType) {
+      case WebAssembly::BlockType::Invalid:
+        llvm_unreachable("Invalid return type");
+      case WebAssembly::BlockType::I32:
+        Returns = {wasm::ValType::I32};
+        break;
+      case WebAssembly::BlockType::I64:
+        Returns = {wasm::ValType::I64};
+        break;
+      case WebAssembly::BlockType::F32:
+        Returns = {wasm::ValType::F32};
+        break;
+      case WebAssembly::BlockType::F64:
+        Returns = {wasm::ValType::F64};
+        break;
+      case WebAssembly::BlockType::Void:
+        Returns = {};
+        break;
+      case WebAssembly::BlockType::Exnref:
+        Returns = {wasm::ValType::EXNREF};
+        break;
+      case WebAssembly::BlockType::Externref:
+        Returns = {wasm::ValType::EXTERNREF};
+        break;
+      case WebAssembly::BlockType::Funcref:
+        Returns = {wasm::ValType::FUNCREF};
+        break;
+      case WebAssembly::BlockType::V128:
+        Returns = {wasm::ValType::V128};
+        break;
+      case WebAssembly::BlockType::Multivalue: {
+        llvm_unreachabl...
[truncated]
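
The packing done in LowerRefTestFuncRef and the unpacking described in the MCInstLower comment (most significant word first, because each word was shifted in from the right) can be illustrated without LLVM. The sketch below models the wide integer as a vector of 64-bit words with index 0 as the least significant word. `Words`, `shiftIn`, `pack`, and `unpack` are hypothetical helpers, not functions from the patch, and the constants are the standard WebAssembly binary type codes used here only as sample data.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Fixed-width integer modeled as 64-bit words; index 0 is least significant.
using Words = std::vector<std::uint64_t>;

constexpr std::uint64_t kVoid = 0x40; // empty block type
constexpr std::uint64_t kI32 = 0x7F;
constexpr std::uint64_t kF64 = 0x7C;

// Equivalent of `Sig <<= 64; Sig |= v;` on a fixed-width integer.
void shiftIn(Words &w, std::uint64_t v) {
  for (std::size_t i = w.size(); i-- > 1;)
    w[i] = w[i - 1];
  w[0] = v;
}

// The return type goes in first, then each parameter is shifted in from the
// right, so the return type ends up in the most significant word.
Words pack(std::uint64_t returnType, const std::vector<std::uint64_t> &params) {
  Words w(params.size() + 1, 0);
  w[0] = returnType;
  for (std::uint64_t p : params)
    shiftIn(w, p);
  return w;
}

struct Unpacked {
  std::uint64_t returnType;
  std::vector<std::uint64_t> params;
};

// Read words from most significant to least significant to recover the
// original order: return type first, then parameters.
Unpacked unpack(const Words &w) {
  Unpacked u{w.back(), {}};
  for (std::size_t i = w.size() - 1; i-- > 0;)
    u.params.push_back(w[i]);
  return u;
}
```

Packing a `void (i32, f64)` signature this way yields three words, and unpacking reads them back in the original order.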

@llvmbot
Member

llvmbot commented Jul 4, 2025

@llvm/pr-subscribers-mc

Author: Hood Chatham (hoodmane)

diff --git a/clang/lib/Sema/SemaWasm.cpp b/clang/lib/Sema/SemaWasm.cpp
index c0fa05bc17609..ca881550fad13 100644
--- a/clang/lib/Sema/SemaWasm.cpp
+++ b/clang/lib/Sema/SemaWasm.cpp
@@ -216,6 +216,53 @@ bool SemaWasm::BuiltinWasmTableCopy(CallExpr *TheCall) {
   return false;
 }
 
+bool SemaWasm::BuiltinWasmTestFunctionPointerSignature(CallExpr *TheCall) {
+  if (SemaRef.checkArgCount(TheCall, 1))
+    return true;
+
+  Expr *FuncPtrArg = TheCall->getArg(0);
+  QualType ArgType = FuncPtrArg->getType();
+
+  // Check that the argument is a function pointer
+  const PointerType *PtrTy = ArgType->getAs<PointerType>();
+  if (!PtrTy) {
+    return Diag(FuncPtrArg->getBeginLoc(),
+                diag::err_typecheck_expect_function_pointer)
+           << ArgType << FuncPtrArg->getSourceRange();
+  }
+
+  const FunctionProtoType *FuncTy =
+      PtrTy->getPointeeType()->getAs<FunctionProtoType>();
+  if (!FuncTy) {
+    return Diag(FuncPtrArg->getBeginLoc(),
+                diag::err_typecheck_expect_function_pointer)
+           << ArgType << FuncPtrArg->getSourceRange();
+  }
+
+  // Check that the function pointer doesn't use reference types
+  if (FuncTy->getReturnType().isWebAssemblyReferenceType()) {
+    return Diag(
+               FuncPtrArg->getBeginLoc(),
+               diag::err_wasm_builtin_test_fp_sig_cannot_include_reference_type)
+           << 0 << FuncTy->getReturnType() << FuncPtrArg->getSourceRange();
+  }
+  auto NParams = FuncTy->getNumParams();
+  for (unsigned I = 0; I < NParams; I++) {
+    if (FuncTy->getParamType(I).isWebAssemblyReferenceType()) {
+      return Diag(
+                 FuncPtrArg->getBeginLoc(),
+                 diag::
+                     err_wasm_builtin_test_fp_sig_cannot_include_reference_type)
+             << 1 << FuncPtrArg->getSourceRange();
+    }
+  }
+
+  // Set return type to int (the result of the test)
+  TheCall->setType(getASTContext().IntTy);
+
+  return false;
+}
+
 bool SemaWasm::CheckWebAssemblyBuiltinFunctionCall(const TargetInfo &TI,
                                                    unsigned BuiltinID,
                                                    CallExpr *TheCall) {
@@ -236,6 +283,8 @@ bool SemaWasm::CheckWebAssemblyBuiltinFunctionCall(const TargetInfo &TI,
     return BuiltinWasmTableFill(TheCall);
   case WebAssembly::BI__builtin_wasm_table_copy:
     return BuiltinWasmTableCopy(TheCall);
+  case WebAssembly::BI__builtin_wasm_test_function_pointer_signature:
+    return BuiltinWasmTestFunctionPointerSignature(TheCall);
   }
 
   return false;
diff --git a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
index f592ff287a0e3..fb61d8a11e5c0 100644
--- a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
+++ b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
@@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn :
   DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem],
                         "llvm.wasm.ref.is_null.exn">;
 
+def int_wasm_ref_test_func
+    : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty],
+                            [IntrNoMem], "llvm.wasm.ref.test.func">;
+
 //===----------------------------------------------------------------------===//
 // Table intrinsics
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp
index 7ee6a3d8304be..c1b3936c1dcec 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp
@@ -668,6 +668,8 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {
       if (parseFunctionTableOperand(&FunctionTable))
         return true;
       ExpectFuncType = true;
+    } else if (Name == "ref.test") {
+      ExpectFuncType = true;
     }
 
     // Returns true if the next tokens are a catch clause
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index aac68b32da13a..1024bb69e6a49 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -18,6 +18,7 @@
 #include "WebAssemblySubtarget.h"
 #include "WebAssemblyTargetMachine.h"
 #include "WebAssemblyUtilities.h"
+#include "llvm/BinaryFormat/Wasm.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -505,6 +506,51 @@ MVT WebAssemblyTargetLowering::getScalarShiftAmountTy(const DataLayout & /*DL*/,
   return Result;
 }
 
+static MachineBasicBlock *LowerRefTestFuncRef(MachineInstr &MI, DebugLoc DL,
+                                              MachineBasicBlock *BB,
+                                              const TargetInstrInfo &TII) {
+  // Lower a REF_TEST_FUNCREF_PSEUDO instruction into a REF_TEST_FUNCREF
+  // instruction by combining the signature info Imm operands that
+  // SelectionDAG/InstrEmitter.cpp makes into one CImm operand. Put this into
+  // the type index placeholder for REF_TEST_FUNCREF.
+  Register ResultReg = MI.getOperand(0).getReg();
+  Register FuncRefReg = MI.getOperand(1).getReg();
+
+  auto NParams = MI.getNumOperands() - 3;
+  auto Sig = APInt(NParams * 64, 0);
+
+  {
+    uint64_t V = MI.getOperand(2).getImm();
+    Sig |= int64_t(V);
+  }
+
+  for (unsigned I = 3; I < MI.getNumOperands(); I++) {
+    const MachineOperand &MO = MI.getOperand(I);
+    if (!MO.isImm()) {
+      // I'm not really sure what these are or where they come from but it seems
+      // to be okay to ignore them
+      continue;
+    }
+    uint16_t V = MO.getImm();
+    Sig <<= 64;
+    Sig |= int64_t(V);
+  }
+
+  ConstantInt *TypeInfo =
+      ConstantInt::get(BB->getParent()->getFunction().getContext(), Sig);
+
+  // Put the type info first in the placeholder for the type index, then the
+  // actual funcref arg
+  BuildMI(*BB, MI, DL, TII.get(WebAssembly::REF_TEST_FUNCREF), ResultReg)
+      .addCImm(TypeInfo)
+      .addReg(FuncRefReg);
+
+  // Remove the original instruction
+  MI.eraseFromParent();
+
+  return BB;
+}
+
 // Lower an fp-to-int conversion operator from the LLVM opcode, which has an
 // undefined result on invalid/overflow, to the WebAssembly opcode, which
 // traps on invalid/overflow.
@@ -866,6 +912,8 @@ MachineBasicBlock *WebAssemblyTargetLowering::EmitInstrWithCustomInserter(
   switch (MI.getOpcode()) {
   default:
     llvm_unreachable("Unexpected instr type to insert");
+  case WebAssembly::REF_TEST_FUNCREF_PSEUDO:
+    return LowerRefTestFuncRef(MI, DL, BB, TII);
   case WebAssembly::FP_TO_SINT_I32_F32:
     return LowerFPToInt(MI, DL, BB, TII, false, false, false,
                         WebAssembly::I32_TRUNC_S_F32);
@@ -2260,6 +2308,72 @@ SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
                            DAG.getTargetExternalSymbol(TlsBase, PtrVT)),
         0);
   }
+  case Intrinsic::wasm_ref_test_func: {
+    // First emit the TABLE_GET instruction to convert function pointer ==>
+    // funcref
+    MachineFunction &MF = DAG.getMachineFunction();
+    auto PtrVT = getPointerTy(MF.getDataLayout());
+    MCSymbol *Table =
+        WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(), Subtarget);
+    SDValue TableSym = DAG.getMCSymbol(Table, PtrVT);
+    SDValue FuncRef =
+        SDValue(DAG.getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL,
+                                   MVT::funcref, TableSym, Op.getOperand(1)),
+                0);
+
+    SmallVector<SDValue, 4> Ops;
+    Ops.push_back(FuncRef);
+
+    // We want to encode the type information into an APInt which we'll put
+    // in a CImm. However, in SelectionDAG/InstrEmitter.cpp there is no code
+    // path that emits a CImm. So we need a custom inserter to put it in.
+
+    // We'll put each type argument in a separate TargetConstant which gets
+    // lowered to a MachineInstr Imm. We combine these into a CImm in our
+    // custom inserter because it creates a problem downstream to have all these
+    // extra immediates.
+    {
+      SDValue Operand = Op.getOperand(2);
+      MVT VT = Operand.getValueType().getSimpleVT();
+      WebAssembly::BlockType V;
+      if (VT == MVT::Untyped) {
+        V = WebAssembly::BlockType::Void;
+      } else if (VT == MVT::i32) {
+        V = WebAssembly::BlockType::I32;
+      } else if (VT == MVT::i64) {
+        V = WebAssembly::BlockType::I64;
+      } else if (VT == MVT::f32) {
+        V = WebAssembly::BlockType::F32;
+      } else if (VT == MVT::f64) {
+        V = WebAssembly::BlockType::F64;
+      } else {
+        llvm_unreachable("Unhandled type!");
+      }
+      Ops.push_back(DAG.getTargetConstant((int64_t)V, DL, MVT::i64));
+    }
+
+    for (unsigned i = 3; i < Op.getNumOperands(); ++i) {
+      SDValue Operand = Op.getOperand(i);
+      MVT VT = Operand.getValueType().getSimpleVT();
+      wasm::ValType V;
+      if (VT == MVT::i32) {
+        V = wasm::ValType::I32;
+      } else if (VT == MVT::i64) {
+        V = wasm::ValType::I64;
+      } else if (VT == MVT::f32) {
+        V = wasm::ValType::F32;
+      } else if (VT == MVT::f64) {
+        V = wasm::ValType::F64;
+      } else {
+        llvm_unreachable("Unhandled type!");
+      }
+      Ops.push_back(DAG.getTargetConstant((int64_t)V, DL, MVT::i64));
+    }
+
+    return SDValue(DAG.getMachineNode(WebAssembly::REF_TEST_FUNCREF_PSEUDO, DL,
+                                      MVT::i32, Ops),
+                   0);
+  }
   }
 }
 
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td
index 2654a09387fd4..0c61f5770e748 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td
@@ -36,6 +36,19 @@ multiclass REF_I<WebAssemblyRegClass rc, ValueType vt, string ht> {
         Requires<[HasReferenceTypes]>;
 }
 
+let usesCustomInserter = 1, isPseudo = 1 in defm REF_TEST_FUNCREF_PSEUDO
+    : I<(outs I32:$res), (ins TypeIndex:$type, FUNCREF:$ref, variable_ops),
+        (outs), (ins TypeIndex:$type), [], "ref.test.pseudo\t$type, $ref",
+        "ref.test.pseudo $type", -1>;
+
+defm REF_TEST_FUNCREF :
+  I<(outs I32: $res),
+    (ins TypeIndex:$type, FUNCREF: $ref),
+    (outs),
+    (ins TypeIndex:$type),
+    [],
+    "ref.test\t$type, $ref", "ref.test $type", 0xfb14>;
+
 defm "" : REF_I<FUNCREF, funcref, "func">;
 defm "" : REF_I<EXTERNREF, externref, "extern">;
 defm "" : REF_I<EXNREF, exnref, "exn">;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
index 8c8629203bca7..770942d09e429 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp
@@ -15,13 +15,17 @@
 #include "WebAssemblyMCInstLower.h"
 #include "MCTargetDesc/WebAssemblyMCExpr.h"
 #include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
+#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h"
 #include "TargetInfo/WebAssemblyTargetInfo.h"
 #include "Utils/WebAssemblyTypeUtilities.h"
 #include "WebAssemblyAsmPrinter.h"
 #include "WebAssemblyMachineFunctionInfo.h"
 #include "WebAssemblyUtilities.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/BinaryFormat/Wasm.h"
 #include "llvm/CodeGen/AsmPrinter.h"
 #include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineOperand.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/MC/MCAsmInfo.h"
 #include "llvm/MC/MCContext.h"
@@ -198,11 +202,81 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI,
       MCOp = MCOperand::createReg(WAReg);
       break;
     }
+    case llvm::MachineOperand::MO_CImmediate: {
+      // Lower the type index placeholder for ref.test.
+      // Currently this is the only way that CImmediates show up, so abort if
+      // we see one anywhere else.
+      unsigned DescIndex = I - NumVariadicDefs;
+      if (DescIndex >= Desc.NumOperands) {
+        llvm_unreachable("unexpected CImmediate operand");
+      }
+      const MCOperandInfo &Info = Desc.operands()[DescIndex];
+      if (Info.OperandType != WebAssembly::OPERAND_TYPEINDEX) {
+        llvm_unreachable("unexpected CImmediate operand");
+      }
+      auto CImm = MO.getCImm()->getValue();
+      auto NumWords = CImm.getNumWords();
+      // Extract the type data we packed into the CImm in LowerRefTestFuncRef.
+      // We need to load the words from most significant to least significant
+      // order because of the way we bitshifted them in from the right.
+      // The return type needs special handling because it could be void.
+      auto ReturnType = static_cast<WebAssembly::BlockType>(
+          CImm.extractBitsAsZExtValue(64, (NumWords - 1) * 64));
+      assert(ReturnType != WebAssembly::BlockType::Invalid);
+      SmallVector<wasm::ValType, 2> Returns;
+      switch (ReturnType) {
+      case WebAssembly::BlockType::Invalid:
+        llvm_unreachable("Invalid return type");
+      case WebAssembly::BlockType::I32:
+        Returns = {wasm::ValType::I32};
+        break;
+      case WebAssembly::BlockType::I64:
+        Returns = {wasm::ValType::I64};
+        break;
+      case WebAssembly::BlockType::F32:
+        Returns = {wasm::ValType::F32};
+        break;
+      case WebAssembly::BlockType::F64:
+        Returns = {wasm::ValType::F64};
+        break;
+      case WebAssembly::BlockType::Void:
+        Returns = {};
+        break;
+      case WebAssembly::BlockType::Exnref:
+        Returns = {wasm::ValType::EXNREF};
+        break;
+      case WebAssembly::BlockType::Externref:
+        Returns = {wasm::ValType::EXTERNREF};
+        break;
+      case WebAssembly::BlockType::Funcref:
+        Returns = {wasm::ValType::FUNCREF};
+        break;
+      case WebAssembly::BlockType::V128:
+        Returns = {wasm::ValType::V128};
+        break;
+      case WebAssembly::BlockType::Multivalue: {
+        llvm_unreachabl...
[truncated]
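
The packing scheme used by `LowerRefTestFuncRef` and read back in `WebAssemblyMCInstLower.cpp` can be sketched outside LLVM. This is a minimal model and not the PR's code: it replaces `APInt` with a vector of 64-bit words stored least significant first, and `TypeCode`, `Signature`, `PackedSig`, `packSignature` and `unpackSignature` are hypothetical names. It assumes the wide integer always has room for every word, which the sketch guarantees by growing the vector.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical type codes. The values follow the WebAssembly binary encoding
// but are only illustrative here.
enum class TypeCode : uint64_t {
  Void = 0x40,
  I32 = 0x7F,
  I64 = 0x7E,
  F32 = 0x7D,
  F64 = 0x7C
};

struct Signature {
  TypeCode Ret;
  std::vector<TypeCode> Params;
};

// Words of a wide integer, least significant word first.
using PackedSig = std::vector<uint64_t>;

// Start with the return type, then shift each parameter in from the right.
PackedSig packSignature(const Signature &Sig) {
  PackedSig Words;
  Words.push_back(static_cast<uint64_t>(Sig.Ret));
  for (TypeCode P : Sig.Params) {
    // Same effect as `Packed <<= 64; Packed |= P` on a wide integer.
    Words.insert(Words.begin(), static_cast<uint64_t>(P));
  }
  return Words;
}

// The return type ends up in the most significant word, followed by the
// parameters in declaration order as we walk towards the least significant.
Signature unpackSignature(const PackedSig &Words) {
  assert(!Words.empty() && "need at least the return type word");
  Signature Sig;
  Sig.Ret = static_cast<TypeCode>(Words.back());
  for (std::size_t I = Words.size() - 1; I-- > 0;)
    Sig.Params.push_back(static_cast<TypeCode>(Words[I]));
  return Sig;
}
```

Because each parameter is shifted in from the right, the reader has to walk the words from most significant to least significant to recover declaration order, which is what the comment in the `MO_CImmediate` case describes.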

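From the caller's side, the Sema checks above imply a builtin that takes one prototyped function pointer and returns an `int`. A hedged usage sketch follows; `BinaryFn`, `signatureMatches`, `callOrDefault` and the `HAVE_WASM_SIG_TEST` macro are hypothetical names. It assumes a compiler built with this PR and a WebAssembly target with whatever features the builtin ends up requiring; on any other compiler it falls back to assuming the signature matches.

```cpp
#include <cassert>

// Detect the builtin only where the compiler reports it.
#if defined(__has_builtin)
#if __has_builtin(__builtin_wasm_test_function_pointer_signature)
#define HAVE_WASM_SIG_TEST 1
#endif
#endif

using BinaryFn = int (*)(int, int);

// Returns true when calling Fn through the BinaryFn type is expected not to
// trap with a signature mismatch.
inline bool signatureMatches(BinaryFn Fn) {
#ifdef HAVE_WASM_SIG_TEST
  return __builtin_wasm_test_function_pointer_signature(Fn) != 0;
#else
  (void)Fn;
  return true; // No runtime check available here; assume the types agree.
#endif
}

// Call Fn only if its runtime signature matches; otherwise return Default.
int callOrDefault(BinaryFn Fn, int A, int B, int Default) {
  assert(Fn != nullptr && "expected a non-null function pointer");
  return signatureMatches(Fn) ? Fn(A, B) : Default;
}

int add(int A, int B) { return A + B; }
```

A trampoline like the CPython one mentioned in the description could use the same pattern to choose between calling conventions instead of returning a default value.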

github-actions bot commented Jul 4, 2025

✅ With the latest revision this PR passed the undef deprecator.

Labels
backend:WebAssembly clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category llvm:ir mc Machine (object) code