diff --git a/clang/include/clang/Basic/BuiltinsWebAssembly.def b/clang/include/clang/Basic/BuiltinsWebAssembly.def index ab480369b3820..85c6ddc125169 100644 --- a/clang/include/clang/Basic/BuiltinsWebAssembly.def +++ b/clang/include/clang/Basic/BuiltinsWebAssembly.def @@ -198,6 +198,12 @@ TARGET_BUILTIN(__builtin_wasm_ref_null_extern, "i", "nct", "reference-types") // return type. TARGET_BUILTIN(__builtin_wasm_ref_null_func, "i", "nct", "reference-types") +// Check if the static type of a function pointer matches its static type. Used +// to avoid "function signature mismatch" traps. Takes a function pointer, uses +// table.get to look up the pointer in __indirect_function_table and then +// ref.test to test the type. +TARGET_BUILTIN(__builtin_wasm_test_function_pointer_signature, "i.", "nct", "reference-types") + // Table builtins TARGET_BUILTIN(__builtin_wasm_table_set, "viii", "t", "reference-types") TARGET_BUILTIN(__builtin_wasm_table_get, "iii", "t", "reference-types") diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index e5a7cdc14a737..8e9af3d5dbfac 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -7460,6 +7460,8 @@ def err_typecheck_illegal_increment_decrement : Error< "cannot %select{decrement|increment}1 value of type %0">; def err_typecheck_expect_int : Error< "used type %0 where integer is required">; +def err_typecheck_expect_function_pointer + : Error<"used type %0 where function pointer is required">; def err_typecheck_expect_hlsl_resource : Error< "used type %0 where __hlsl_resource_t is required">; def err_typecheck_arithmetic_incomplete_or_sizeless_type : Error< @@ -12995,6 +12997,10 @@ def err_wasm_builtin_arg_must_match_table_element_type : Error < "%ordinal0 argument must match the element type of the WebAssembly table in the %ordinal1 argument">; def err_wasm_builtin_arg_must_be_integer_type : Error < "%ordinal0 argument must be an integer">; +def err_wasm_builtin_test_fp_sig_cannot_include_reference_type + : Error<"not supported for " + "function pointers with a reference type %select{return " + "value|parameter}0">; // OpenACC diagnostics. def warn_acc_routine_unimplemented diff --git a/clang/include/clang/Sema/SemaWasm.h b/clang/include/clang/Sema/SemaWasm.h index 8841fdff23035..f97b72ff58579 100644 --- a/clang/include/clang/Sema/SemaWasm.h +++ b/clang/include/clang/Sema/SemaWasm.h @@ -36,6 +36,7 @@ class SemaWasm : public SemaBase { bool BuiltinWasmTableGrow(CallExpr *TheCall); bool BuiltinWasmTableFill(CallExpr *TheCall); bool BuiltinWasmTableCopy(CallExpr *TheCall); + bool BuiltinWasmTestFunctionPointerSignature(CallExpr *TheCall); WebAssemblyImportNameAttr * mergeImportNameAttr(Decl *D, const WebAssemblyImportNameAttr &AL); diff --git a/clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp b/clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp index 698f43215a1be..56231a3a357c9 100644 --- a/clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp +++ b/clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp @@ -12,7 +12,10 @@ #include "CGBuiltin.h" #include "clang/Basic/TargetBuiltins.h" +#include "llvm/ADT/APInt.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/IntrinsicsWebAssembly.h" +#include "llvm/Support/ErrorHandling.h" using namespace clang; using namespace CodeGen; @@ -213,6 +216,61 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID, Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_ref_null_func); return Builder.CreateCall(Callee); } + case WebAssembly::BI__builtin_wasm_test_function_pointer_signature: { + Value *FuncRef = EmitScalarExpr(E->getArg(0)); + + // Get the function type from the argument's static type + QualType ArgType = E->getArg(0)->getType(); + const PointerType *PtrTy = ArgType->getAs(); + assert(PtrTy && "Sema should have ensured this is a function pointer"); + + const FunctionType *FuncTy = PtrTy->getPointeeType()->getAs(); + assert(FuncTy && "Sema should have ensured this is a function pointer"); + + // In the llvm IR, we won't have access anymore to the type of the function + // pointer so we need to insert this type information somehow. We gave the + // @llvm.wasm.ref.test.func varargs and here we add an extra 0 argument of + // the type corresponding to the type of each argument of the function + // signature. When we lower from the IR we'll use the types of these + // arguments to determine the signature we want to test for. + + // Make a type index constant with 0. This gets replaced by the actual type + // in WebAssemblyMCInstLower.cpp. + llvm::FunctionType *LLVMFuncTy = + cast(ConvertType(QualType(FuncTy, 0))); + + uint NParams = LLVMFuncTy->getNumParams(); + std::vector Args; + Args.reserve(NParams + 1); + // The only real argument is the FuncRef + Args.push_back(FuncRef); + + // Add the type information + auto addType = [this, &Args](llvm::Type *T) { + if (T->isVoidTy()) { + // Use TokenTy as dummy for void b/c the verifier rejects a + // void arg with 'Instruction operands must be first-class values!' + // TokenTy isn't a first class value either but apparently the verifier + // doesn't mind it. + Args.push_back( + PoisonValue::get(llvm::Type::getTokenTy(getLLVMContext()))); + } else if (T->isFloatingPointTy()) { + Args.push_back(ConstantFP::get(T, 0)); + } else if (T->isIntegerTy()) { + Args.push_back(ConstantInt::get(T, 0)); + } else { + // TODO: Handle reference types here. For now, we reject them in Sema. + llvm_unreachable("Unhandled type"); + } + }; + + addType(LLVMFuncTy->getReturnType()); + for (uint i = 0; i < NParams; i++) { + addType(LLVMFuncTy->getParamType(i)); + } + Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_ref_test_func); + return Builder.CreateCall(Callee, Args); + } case WebAssembly::BI__builtin_wasm_swizzle_i8x16: { Value *Src = EmitScalarExpr(E->getArg(0)); Value *Indices = EmitScalarExpr(E->getArg(1)); diff --git a/clang/lib/Sema/SemaWasm.cpp b/clang/lib/Sema/SemaWasm.cpp index c0fa05bc17609..ca881550fad13 100644 --- a/clang/lib/Sema/SemaWasm.cpp +++ b/clang/lib/Sema/SemaWasm.cpp @@ -216,6 +216,53 @@ bool SemaWasm::BuiltinWasmTableCopy(CallExpr *TheCall) { return false; } +bool SemaWasm::BuiltinWasmTestFunctionPointerSignature(CallExpr *TheCall) { + if (SemaRef.checkArgCount(TheCall, 1)) + return true; + + Expr *FuncPtrArg = TheCall->getArg(0); + QualType ArgType = FuncPtrArg->getType(); + + // Check that the argument is a function pointer + const PointerType *PtrTy = ArgType->getAs(); + if (!PtrTy) { + return Diag(FuncPtrArg->getBeginLoc(), + diag::err_typecheck_expect_function_pointer) + << ArgType << FuncPtrArg->getSourceRange(); + } + + const FunctionProtoType *FuncTy = + PtrTy->getPointeeType()->getAs(); + if (!FuncTy) { + return Diag(FuncPtrArg->getBeginLoc(), + diag::err_typecheck_expect_function_pointer) + << ArgType << FuncPtrArg->getSourceRange(); + } + + // Check that the function pointer doesn't use reference types + if (FuncTy->getReturnType().isWebAssemblyReferenceType()) { + return Diag( + FuncPtrArg->getBeginLoc(), + diag::err_wasm_builtin_test_fp_sig_cannot_include_reference_type) + << 0 << FuncTy->getReturnType() << FuncPtrArg->getSourceRange(); + } + auto NParams = FuncTy->getNumParams(); + for (unsigned I = 0; I < NParams; I++) { + if (FuncTy->getParamType(I).isWebAssemblyReferenceType()) { + return Diag( + FuncPtrArg->getBeginLoc(), + diag:: + err_wasm_builtin_test_fp_sig_cannot_include_reference_type) + << 1 << FuncPtrArg->getSourceRange(); + } + } + + // Set return type to int (the result of the test) + TheCall->setType(getASTContext().IntTy); + + return false; +} + bool SemaWasm::CheckWebAssemblyBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID, CallExpr *TheCall) { @@ -236,6 +283,8 @@ bool SemaWasm::CheckWebAssemblyBuiltinFunctionCall(const TargetInfo &TI, return BuiltinWasmTableFill(TheCall); case WebAssembly::BI__builtin_wasm_table_copy: return BuiltinWasmTableCopy(TheCall); + case WebAssembly::BI__builtin_wasm_test_function_pointer_signature: + return BuiltinWasmTestFunctionPointerSignature(TheCall); } return false; diff --git a/clang/test/CodeGen/builtins-wasm.c b/clang/test/CodeGen/builtins-wasm.c index 263cfd3ab4c69..554a3555f2f57 100644 --- a/clang/test/CodeGen/builtins-wasm.c +++ b/clang/test/CodeGen/builtins-wasm.c @@ -745,3 +745,27 @@ void *tp (void) { return __builtin_thread_pointer (); // WEBASSEMBLY: call {{.*}} @llvm.thread.pointer() } + + +typedef void (*funcref_t)(); +typedef int (*funcref_int_t)(int); +typedef float (*F1)(float, double, int); +typedef int (*F2)(float, double, int); +typedef int (*F3)(int, int, int); +typedef void (*F4)(int, int, int); +typedef void (*F5)(void); + +void use(int); + +void test_function_pointer_signature_void(F1 func) { + // WEBASSEMBLY: %0 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, float 0.000000e+00, float 0.000000e+00, double 0.000000e+00, i32 0) + use(__builtin_wasm_test_function_pointer_signature(func)); + // WEBASSEMBLY: %1 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, float 0.000000e+00, double 0.000000e+00, i32 0) + use(__builtin_wasm_test_function_pointer_signature((F2)func)); + // WEBASSEMBLY: %2 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, i32 0, i32 0, i32 0) + use(__builtin_wasm_test_function_pointer_signature((F3)func)); + // WEBASSEMBLY: %3 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison, i32 0, i32 0, i32 0) + use(__builtin_wasm_test_function_pointer_signature((F4)func)); + // WEBASSEMBLY: %4 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison) + use(__builtin_wasm_test_function_pointer_signature((F5)func)); +} diff --git a/clang/test/Sema/builtins-wasm.c b/clang/test/Sema/builtins-wasm.c index beb430616233a..de9f493f34833 100644 --- a/clang/test/Sema/builtins-wasm.c +++ b/clang/test/Sema/builtins-wasm.c @@ -50,3 +50,27 @@ void test_table_copy(int dst_idx, int src_idx, int nelem) { __builtin_wasm_table_copy(table, table, dst_idx, src_idx, table); // expected-error {{5th argument must be an integer}} __builtin_wasm_table_copy(table, table, dst_idx, src_idx, nelem); } + +typedef void (*F1)(void); +typedef int (*F2)(int); +typedef int (*F3)(__externref_t); +typedef __externref_t (*F4)(int); + +void test_function_pointer_signature() { + // Test argument count validation + (void)__builtin_wasm_test_function_pointer_signature(); // expected-error {{too few arguments to function call, expected 1, have 0}} + (void)__builtin_wasm_test_function_pointer_signature((F1)0, (F2)0); // expected-error {{too many arguments to function call, expected 1, have 2}} + + // // Test argument type validation - should require function pointer + (void)__builtin_wasm_test_function_pointer_signature((void*)0); // expected-error {{used type 'void *' where function pointer is required}} + (void)__builtin_wasm_test_function_pointer_signature((int)0); // expected-error {{used type 'int' where function pointer is required}} + (void)__builtin_wasm_test_function_pointer_signature((F3)0); // expected-error {{not supported for function pointers with a reference type parameter}} + (void)__builtin_wasm_test_function_pointer_signature((F4)0); // expected-error {{not supported for function pointers with a reference type return value}} + + // // Test valid usage + int res = __builtin_wasm_test_function_pointer_signature((F1)0); + res = __builtin_wasm_test_function_pointer_signature((F2)0); + + // Test return type + _Static_assert(EXPR_HAS_TYPE(__builtin_wasm_test_function_pointer_signature((F1)0), int), ""); +} diff --git a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td index f592ff287a0e3..fb61d8a11e5c0 100644 --- a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td +++ b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td @@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem], "llvm.wasm.ref.is_null.exn">; +def int_wasm_ref_test_func + : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty], + [IntrNoMem], "llvm.wasm.ref.test.func">; + //===----------------------------------------------------------------------===// // Table intrinsics //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp index 7ee6a3d8304be..c1b3936c1dcec 100644 --- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp +++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp @@ -668,6 +668,8 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser { if (parseFunctionTableOperand(&FunctionTable)) return true; ExpectFuncType = true; + } else if (Name == "ref.test") { + ExpectFuncType = true; } // Returns true if the next tokens are a catch clause diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index aac68b32da13a..f86517cb6d67c 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -18,6 +18,7 @@ #include "WebAssemblySubtarget.h" #include "WebAssemblyTargetMachine.h" #include "WebAssemblyUtilities.h" +#include "llvm/BinaryFormat/Wasm.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineInstrBuilder.h" @@ -505,6 +506,51 @@ MVT WebAssemblyTargetLowering::getScalarShiftAmountTy(const DataLayout & /*DL*/, return Result; } +static MachineBasicBlock *LowerRefTestFuncRef(MachineInstr &MI, DebugLoc DL, + MachineBasicBlock *BB, + const TargetInstrInfo &TII) { + // Lower a REF_TEST_FUNCREF_PSEUDO instruction into a REF_TEST_FUNCREF + // instruction by combining the signature info Imm operands that + // SelectionDag/InstrEmitter.cpp makes into one CImm operand. Put this into + // the type index placeholder for REF_TEST_FUNCREF + Register ResultReg = MI.getOperand(0).getReg(); + Register FuncRefReg = MI.getOperand(1).getReg(); + + auto NParams = MI.getNumOperands() - 3; + auto Sig = APInt(NParams * 64, 0); + + { + uint64_t V = MI.getOperand(2).getImm(); + Sig |= int64_t(V); + } + + for (unsigned I = 3; I < MI.getNumOperands(); I++) { + const MachineOperand &MO = MI.getOperand(I); + if (!MO.isImm()) { + // I'm not really sure what these are or where they come from but it seems + // to be okay to ignore them + continue; + } + uint16_t V = MO.getImm(); + Sig <<= 64; + Sig |= int64_t(V); + } + + ConstantInt *TypeInfo = + ConstantInt::get(BB->getParent()->getFunction().getContext(), Sig); + + // Put the type info first in the placeholder for the type index, then the + // actual funcref arg + BuildMI(*BB, MI, DL, TII.get(WebAssembly::REF_TEST_FUNCREF), ResultReg) + .addCImm(TypeInfo) + .addReg(FuncRefReg); + + // Remove the original instruction + MI.eraseFromParent(); + + return BB; +} + // Lower an fp-to-int conversion operator from the LLVM opcode, which has an // undefined result on invalid/overflow, to the WebAssembly opcode, which // traps on invalid/overflow. @@ -798,6 +844,7 @@ LowerCallResults(MachineInstr &CallResults, DebugLoc DL, MachineBasicBlock *BB, if (IsIndirect) { // Placeholder for the type index. + // This gets replaced with the correct value in WebAssemblyMCInstLower.cpp MIB.addImm(0); // The table into which this call_indirect indexes. MCSymbolWasm *Table = IsFuncrefCall @@ -866,6 +913,8 @@ MachineBasicBlock *WebAssemblyTargetLowering::EmitInstrWithCustomInserter( switch (MI.getOpcode()) { default: llvm_unreachable("Unexpected instr type to insert"); + case WebAssembly::REF_TEST_FUNCREF_PSEUDO: + return LowerRefTestFuncRef(MI, DL, BB, TII); case WebAssembly::FP_TO_SINT_I32_F32: return LowerFPToInt(MI, DL, BB, TII, false, false, false, WebAssembly::I32_TRUNC_S_F32); @@ -2260,6 +2309,72 @@ SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op, DAG.getTargetExternalSymbol(TlsBase, PtrVT)), 0); } + case Intrinsic::wasm_ref_test_func: { + // First emit the TABLE_GET instruction to convert function pointer ==> + // funcref + MachineFunction &MF = DAG.getMachineFunction(); + auto PtrVT = getPointerTy(MF.getDataLayout()); + MCSymbol *Table = + WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(), Subtarget); + SDValue TableSym = DAG.getMCSymbol(Table, PtrVT); + SDValue FuncRef = + SDValue(DAG.getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL, + MVT::funcref, TableSym, Op.getOperand(1)), + 0); + + SmallVector Ops; + Ops.push_back(FuncRef); + + // We want to encode the type information into an APInt which we'll put + // in a CImm. However, in SelectionDag/InstrEmitter.cpp there is no code + // path that emits a CImm. So we need a custom inserter to put it in. + + // We'll put each type argument in a separate TargetConstant which gets + // lowered to a MachineInstruction Imm. We combine these into a CImm in our + // custom inserter because it creates a problem downstream to have all these + // extra immediates. + { + SDValue Operand = Op.getOperand(2); + MVT VT = Operand.getValueType().getSimpleVT(); + WebAssembly::BlockType V; + if (VT == MVT::Untyped) { + V = WebAssembly::BlockType::Void; + } else if (VT == MVT::i32) { + V = WebAssembly::BlockType::I32; + } else if (VT == MVT::i64) { + V = WebAssembly::BlockType::I64; + } else if (VT == MVT::f32) { + V = WebAssembly::BlockType::F32; + } else if (VT == MVT::f64) { + V = WebAssembly::BlockType::F64; + } else { + llvm_unreachable("Unhandled type!"); + } + Ops.push_back(DAG.getTargetConstant((int64_t)V, DL, MVT::i64)); + } + + for (unsigned i = 3; i < Op.getNumOperands(); ++i) { + SDValue Operand = Op.getOperand(i); + MVT VT = Operand.getValueType().getSimpleVT(); + wasm::ValType V; + if (VT == MVT::i32) { + V = wasm::ValType::I32; + } else if (VT == MVT::i64) { + V = wasm::ValType::I64; + } else if (VT == MVT::f32) { + V = wasm::ValType::F32; + } else if (VT == MVT::f64) { + V = wasm::ValType::F64; + } else { + llvm_unreachable("Unhandled type!"); + } + Ops.push_back(DAG.getTargetConstant((int64_t)V, DL, MVT::i64)); + } + + return SDValue(DAG.getMachineNode(WebAssembly::REF_TEST_FUNCREF_PSEUDO, DL, + MVT::i32, Ops), + 0); + } } } diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td index 2654a09387fd4..0c61f5770e748 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrRef.td @@ -36,6 +36,19 @@ multiclass REF_I { Requires<[HasReferenceTypes]>; } +let usesCustomInserter = 1, isPseudo = 1 in defm REF_TEST_FUNCREF_PSEUDO + : I<(outs I32:$res), (ins TypeIndex:$type, FUNCREF:$ref, variable_ops), + (outs), (ins TypeIndex:$type), [], "ref.test.pseudo\t$type, $ref", + "ref.test.pseudo $type", -1>; + +defm REF_TEST_FUNCREF : + I<(outs I32: $res), + (ins TypeIndex:$type, FUNCREF: $ref), + (outs), + (ins TypeIndex:$type), + [], + "ref.test\t$type, $ref", "ref.test $type", 0xfb14>; + defm "" : REF_I; defm "" : REF_I; defm "" : REF_I; diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp index 8c8629203bca7..61beb1f909b5f 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyMCInstLower.cpp @@ -15,13 +15,17 @@ #include "WebAssemblyMCInstLower.h" #include "MCTargetDesc/WebAssemblyMCExpr.h" #include "MCTargetDesc/WebAssemblyMCTargetDesc.h" +#include "MCTargetDesc/WebAssemblyMCTypeUtilities.h" #include "TargetInfo/WebAssemblyTargetInfo.h" #include "Utils/WebAssemblyTypeUtilities.h" #include "WebAssemblyAsmPrinter.h" #include "WebAssemblyMachineFunctionInfo.h" #include "WebAssemblyUtilities.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/BinaryFormat/Wasm.h" #include "llvm/CodeGen/AsmPrinter.h" #include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineOperand.h" #include "llvm/IR/Constants.h" #include "llvm/MC/MCAsmInfo.h" #include "llvm/MC/MCContext.h" @@ -198,11 +202,80 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI, MCOp = MCOperand::createReg(WAReg); break; } + case llvm::MachineOperand::MO_CImmediate: { + // Lower type index placeholder for ref.test + // Currently this is the only way that CImmediates show up so panic if we + // get confused. + unsigned DescIndex = I - NumVariadicDefs; + if (DescIndex >= Desc.NumOperands) { + llvm_unreachable("unexpected CImmediate operand"); + } + const MCOperandInfo &Info = Desc.operands()[DescIndex]; + if (Info.OperandType != WebAssembly::OPERAND_TYPEINDEX) { + llvm_unreachable("unexpected CImmediate operand"); + } + auto CImm = MO.getCImm()->getValue(); + auto NumWords = CImm.getNumWords(); + // Extract the type data we packed into the CImm in LowerRefTestFuncRef. + // We need to load the words from most significant to least significant + // order because of the way we bitshifted them in from the right. + // The return type needs special handling because it could be void. + auto ReturnType = static_cast( + CImm.extractBitsAsZExtValue(64, (NumWords - 1) * 64)); + SmallVector Returns; + switch (ReturnType) { + case WebAssembly::BlockType::Invalid: + llvm_unreachable("Invalid return type"); + case WebAssembly::BlockType::I32: + Returns = {wasm::ValType::I32}; + break; + case WebAssembly::BlockType::I64: + Returns = {wasm::ValType::I64}; + break; + case WebAssembly::BlockType::F32: + Returns = {wasm::ValType::F32}; + break; + case WebAssembly::BlockType::F64: + Returns = {wasm::ValType::F64}; + break; + case WebAssembly::BlockType::Void: + Returns = {}; + break; + case WebAssembly::BlockType::Exnref: + Returns = {wasm::ValType::EXNREF}; + break; + case WebAssembly::BlockType::Externref: + Returns = {wasm::ValType::EXTERNREF}; + break; + case WebAssembly::BlockType::Funcref: + Returns = {wasm::ValType::FUNCREF}; + break; + case WebAssembly::BlockType::V128: + Returns = {wasm::ValType::V128}; + break; + case WebAssembly::BlockType::Multivalue: { + llvm_unreachable("Invalid return type"); + } + } + SmallVector Params; + + for (int I = NumWords - 2; I >= 0; I--) { + auto Val = CImm.extractBitsAsZExtValue(64, 64 * I); + auto ParamType = static_cast(Val); + Params.push_back(ParamType); + } + MCOp = lowerTypeIndexOperand(std::move(Returns), std::move(Params)); + break; + } case MachineOperand::MO_Immediate: { unsigned DescIndex = I - NumVariadicDefs; if (DescIndex < Desc.NumOperands) { const MCOperandInfo &Info = Desc.operands()[DescIndex]; + // Replace type index placeholder with actual type index. The type index + // placeholders are Immediates and have an operand type of + // OPERAND_TYPEINDEX or OPERAND_SIGNATURE. if (Info.OperandType == WebAssembly::OPERAND_TYPEINDEX) { + // Lower type index placeholder for a CALL_INDIRECT instruction SmallVector Returns; SmallVector Params; @@ -230,6 +303,7 @@ void WebAssemblyMCInstLower::lower(const MachineInstr *MI, break; } if (Info.OperandType == WebAssembly::OPERAND_SIGNATURE) { + // Lower type index placeholder for blocks auto BT = static_cast(MO.getImm()); assert(BT != WebAssembly::BlockType::Invalid); if (BT == WebAssembly::BlockType::Multivalue) { diff --git a/llvm/test/CodeGen/WebAssembly/ref-test-func.ll b/llvm/test/CodeGen/WebAssembly/ref-test-func.ll new file mode 100644 index 0000000000000..3fc848cd167f9 --- /dev/null +++ b/llvm/test/CodeGen/WebAssembly/ref-test-func.ll @@ -0,0 +1,42 @@ +; RUN: llc < %s -mcpu=mvp -mattr=+reference-types | FileCheck %s + +target triple = "wasm32-unknown-unknown" + +; CHECK-LABEL: test_function_pointer_signature_void: +; CHECK-NEXT: .functype test_function_pointer_signature_void (i32) -> () +; CHECK-NEXT: .local funcref +; CHECK: local.get 0 +; CHECK-NEXT: table.get __indirect_function_table +; CHECK-NEXT: local.tee 1 +; CHECK-NEXT: ref.test (f32, f64, i32) -> (f32) +; CHECK-NEXT: call use +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: ref.test (f32, f64, i32) -> (i32) +; CHECK-NEXT: call use +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: ref.test (i32, i32, i32) -> (i32) +; CHECK-NEXT: call use +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: ref.test (i32, i32, i32) -> () +; CHECK-NEXT: call use +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: ref.test () -> () +; CHECK-NEXT: call use + +; Function Attrs: nounwind +define void @test_function_pointer_signature_void(ptr noundef %func) local_unnamed_addr #0 { +entry: + %0 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, float 0.000000e+00, float 0.000000e+00, double 0.000000e+00, i32 0) + tail call void @use(i32 noundef %0) #3 + %1 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, float 0.000000e+00, double 0.000000e+00, i32 0) + tail call void @use(i32 noundef %1) #3 + %2 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, i32 0, i32 0, i32 0) + tail call void @use(i32 noundef %2) #3 + %3 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison, i32 0, i32 0, i32 0) + tail call void @use(i32 noundef %3) #3 + %4 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison) + tail call void @use(i32 noundef %4) #3 + ret void +} + +declare void @use(i32 noundef) local_unnamed_addr #1 diff --git a/llvm/test/MC/WebAssembly/reference-types.s b/llvm/test/MC/WebAssembly/reference-types.s index cfadede8295ef..08aafb23969eb 100644 --- a/llvm/test/MC/WebAssembly/reference-types.s +++ b/llvm/test/MC/WebAssembly/reference-types.s @@ -27,6 +27,21 @@ ref_null_test: drop end_function +# CHECK-LABEL: ref_test_test: +# CHECK: ref.null_func # encoding: [0xd0,0x70] +# CHECK: ref.test () -> () # encoding: [0xfb,0x14,0x80'A',0x80'A',0x80'A',0x80'A',A] +# CHECK: # fixup A - offset: 2, value: .Ltypeindex0@TYPEINDEX, kind: fixup_uleb128_i32 +# CHECK: ref.null_func # encoding: [0xd0,0x70] +# CHECK: ref.test () -> (i32) # encoding: [0xfb,0x14,0x80'A',0x80'A',0x80'A',0x80'A',A] +# CHECK: # fixup A - offset: 2, value: .Ltypeindex1@TYPEINDEX, kind: fixup_uleb128_i32 +ref_test_test: + .functype ref_test_test () -> (i32, i32) + ref.null_func + ref.test () -> () + ref.null_func + ref.test () -> (i32) + end_function + # CHECK-LABEL: ref_sig_test_funcref: # CHECK-NEXT: .functype ref_sig_test_funcref (funcref) -> (funcref) ref_sig_test_funcref: