Skip to content

Add __builtin_wasm_test_function_pointer_signature #147076

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/BuiltinsWebAssembly.def
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ TARGET_BUILTIN(__builtin_wasm_ref_null_extern, "i", "nct", "reference-types")
// return type.
TARGET_BUILTIN(__builtin_wasm_ref_null_func, "i", "nct", "reference-types")

// Check if the static type of a function pointer matches its static type. Used
// to avoid "function signature mismatch" traps. Takes a function pointer, uses
// table.get to look up the pointer in __indirect_function_table and then
// ref.test to test the type.
TARGET_BUILTIN(__builtin_wasm_test_function_pointer_signature, "i.", "nct", "reference-types")

// Table builtins
TARGET_BUILTIN(__builtin_wasm_table_set, "viii", "t", "reference-types")
TARGET_BUILTIN(__builtin_wasm_table_get, "iii", "t", "reference-types")
Expand Down
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -7460,6 +7460,8 @@ def err_typecheck_illegal_increment_decrement : Error<
"cannot %select{decrement|increment}1 value of type %0">;
def err_typecheck_expect_int : Error<
"used type %0 where integer is required">;
def err_typecheck_expect_function_pointer
: Error<"used type %0 where function pointer is required">;
def err_typecheck_expect_hlsl_resource : Error<
"used type %0 where __hlsl_resource_t is required">;
def err_typecheck_arithmetic_incomplete_or_sizeless_type : Error<
Expand Down Expand Up @@ -12995,6 +12997,10 @@ def err_wasm_builtin_arg_must_match_table_element_type : Error <
"%ordinal0 argument must match the element type of the WebAssembly table in the %ordinal1 argument">;
def err_wasm_builtin_arg_must_be_integer_type : Error <
"%ordinal0 argument must be an integer">;
def err_wasm_builtin_test_fp_sig_cannot_include_reference_type
: Error<"not supported for "
"function pointers with a reference type %select{return "
"value|parameter}0">;

// OpenACC diagnostics.
def warn_acc_routine_unimplemented
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Sema/SemaWasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class SemaWasm : public SemaBase {
bool BuiltinWasmTableGrow(CallExpr *TheCall);
bool BuiltinWasmTableFill(CallExpr *TheCall);
bool BuiltinWasmTableCopy(CallExpr *TheCall);
bool BuiltinWasmTestFunctionPointerSignature(CallExpr *TheCall);

WebAssemblyImportNameAttr *
mergeImportNameAttr(Decl *D, const WebAssemblyImportNameAttr &AL);
Expand Down
58 changes: 58 additions & 0 deletions clang/lib/CodeGen/TargetBuiltins/WebAssembly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@

#include "CGBuiltin.h"
#include "clang/Basic/TargetBuiltins.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IntrinsicsWebAssembly.h"
#include "llvm/Support/ErrorHandling.h"

using namespace clang;
using namespace CodeGen;
Expand Down Expand Up @@ -213,6 +216,61 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_ref_null_func);
return Builder.CreateCall(Callee);
}
case WebAssembly::BI__builtin_wasm_test_function_pointer_signature: {
Value *FuncRef = EmitScalarExpr(E->getArg(0));

// Get the function type from the argument's static type
QualType ArgType = E->getArg(0)->getType();
const PointerType *PtrTy = ArgType->getAs<PointerType>();
assert(PtrTy && "Sema should have ensured this is a function pointer");

const FunctionType *FuncTy = PtrTy->getPointeeType()->getAs<FunctionType>();
assert(FuncTy && "Sema should have ensured this is a function pointer");

// In the llvm IR, we won't have access anymore to the type of the function
// pointer so we need to insert this type information somehow. We gave the
// @llvm.wasm.ref.test.func varargs and here we add an extra 0 argument of
// the type corresponding to the type of each argument of the function
// signature. When we lower from the IR we'll use the types of these
// arguments to determine the signature we want to test for.

// Make a type index constant with 0. This gets replaced by the actual type
// in WebAssemblyMCInstLower.cpp.
llvm::FunctionType *LLVMFuncTy =
cast<llvm::FunctionType>(ConvertType(QualType(FuncTy, 0)));

uint NParams = LLVMFuncTy->getNumParams();
std::vector<Value *> Args;
Args.reserve(NParams + 1);
// The only real argument is the FuncRef
Args.push_back(FuncRef);

// Add the type information
auto addType = [this, &Args](llvm::Type *T) {
if (T->isVoidTy()) {
// Use TokenTy as dummy for void b/c the verifier rejects a
// void arg with 'Instruction operands must be first-class values!'
// TokenTy isn't a first class value either but apparently the verifier
// doesn't mind it.
Args.push_back(
PoisonValue::get(llvm::Type::getTokenTy(getLLVMContext())));
} else if (T->isFloatingPointTy()) {
Args.push_back(ConstantFP::get(T, 0));
} else if (T->isIntegerTy()) {
Args.push_back(ConstantInt::get(T, 0));
} else {
// TODO: Handle reference types here. For now, we reject them in Sema.
llvm_unreachable("Unhandled type");
}
};

addType(LLVMFuncTy->getReturnType());
for (uint i = 0; i < NParams; i++) {
addType(LLVMFuncTy->getParamType(i));
}
Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_ref_test_func);
return Builder.CreateCall(Callee, Args);
}
case WebAssembly::BI__builtin_wasm_swizzle_i8x16: {
Value *Src = EmitScalarExpr(E->getArg(0));
Value *Indices = EmitScalarExpr(E->getArg(1));
Expand Down
49 changes: 49 additions & 0 deletions clang/lib/Sema/SemaWasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,53 @@ bool SemaWasm::BuiltinWasmTableCopy(CallExpr *TheCall) {
return false;
}

bool SemaWasm::BuiltinWasmTestFunctionPointerSignature(CallExpr *TheCall) {
if (SemaRef.checkArgCount(TheCall, 1))
return true;

Expr *FuncPtrArg = TheCall->getArg(0);
QualType ArgType = FuncPtrArg->getType();

// Check that the argument is a function pointer
const PointerType *PtrTy = ArgType->getAs<PointerType>();
if (!PtrTy) {
return Diag(FuncPtrArg->getBeginLoc(),
diag::err_typecheck_expect_function_pointer)
<< ArgType << FuncPtrArg->getSourceRange();
}

const FunctionProtoType *FuncTy =
PtrTy->getPointeeType()->getAs<FunctionProtoType>();
if (!FuncTy) {
return Diag(FuncPtrArg->getBeginLoc(),
diag::err_typecheck_expect_function_pointer)
<< ArgType << FuncPtrArg->getSourceRange();
}

// Check that the function pointer doesn't use reference types
if (FuncTy->getReturnType().isWebAssemblyReferenceType()) {
return Diag(
FuncPtrArg->getBeginLoc(),
diag::err_wasm_builtin_test_fp_sig_cannot_include_reference_type)
<< 0 << FuncTy->getReturnType() << FuncPtrArg->getSourceRange();
}
auto NParams = FuncTy->getNumParams();
for (unsigned I = 0; I < NParams; I++) {
if (FuncTy->getParamType(I).isWebAssemblyReferenceType()) {
return Diag(
FuncPtrArg->getBeginLoc(),
diag::
err_wasm_builtin_test_fp_sig_cannot_include_reference_type)
<< 1 << FuncPtrArg->getSourceRange();
}
}

// Set return type to int (the result of the test)
TheCall->setType(getASTContext().IntTy);

return false;
}

bool SemaWasm::CheckWebAssemblyBuiltinFunctionCall(const TargetInfo &TI,
unsigned BuiltinID,
CallExpr *TheCall) {
Expand All @@ -236,6 +283,8 @@ bool SemaWasm::CheckWebAssemblyBuiltinFunctionCall(const TargetInfo &TI,
return BuiltinWasmTableFill(TheCall);
case WebAssembly::BI__builtin_wasm_table_copy:
return BuiltinWasmTableCopy(TheCall);
case WebAssembly::BI__builtin_wasm_test_function_pointer_signature:
return BuiltinWasmTestFunctionPointerSignature(TheCall);
}

return false;
Expand Down
24 changes: 24 additions & 0 deletions clang/test/CodeGen/builtins-wasm.c
Original file line number Diff line number Diff line change
Expand Up @@ -745,3 +745,27 @@ void *tp (void) {
return __builtin_thread_pointer ();
// WEBASSEMBLY: call {{.*}} @llvm.thread.pointer()
}


typedef void (*funcref_t)();
typedef int (*funcref_int_t)(int);
typedef float (*F1)(float, double, int);
typedef int (*F2)(float, double, int);
typedef int (*F3)(int, int, int);
typedef void (*F4)(int, int, int);
typedef void (*F5)(void);

void use(int);

void test_function_pointer_signature_void(F1 func) {
// WEBASSEMBLY: %0 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, float 0.000000e+00, float 0.000000e+00, double 0.000000e+00, i32 0)
use(__builtin_wasm_test_function_pointer_signature(func));
// WEBASSEMBLY: %1 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, float 0.000000e+00, double 0.000000e+00, i32 0)
use(__builtin_wasm_test_function_pointer_signature((F2)func));
// WEBASSEMBLY: %2 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, i32 0, i32 0, i32 0, i32 0)
use(__builtin_wasm_test_function_pointer_signature((F3)func));
// WEBASSEMBLY: %3 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison, i32 0, i32 0, i32 0)
use(__builtin_wasm_test_function_pointer_signature((F4)func));
// WEBASSEMBLY: %4 = tail call i32 (ptr, ...) @llvm.wasm.ref.test.func(ptr %func, token poison)
use(__builtin_wasm_test_function_pointer_signature((F5)func));
}
24 changes: 24 additions & 0 deletions clang/test/Sema/builtins-wasm.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,27 @@ void test_table_copy(int dst_idx, int src_idx, int nelem) {
__builtin_wasm_table_copy(table, table, dst_idx, src_idx, table); // expected-error {{5th argument must be an integer}}
__builtin_wasm_table_copy(table, table, dst_idx, src_idx, nelem);
}

typedef void (*F1)(void);
typedef int (*F2)(int);
typedef int (*F3)(__externref_t);
typedef __externref_t (*F4)(int);

void test_function_pointer_signature() {
// Test argument count validation
(void)__builtin_wasm_test_function_pointer_signature(); // expected-error {{too few arguments to function call, expected 1, have 0}}
(void)__builtin_wasm_test_function_pointer_signature((F1)0, (F2)0); // expected-error {{too many arguments to function call, expected 1, have 2}}

// // Test argument type validation - should require function pointer
(void)__builtin_wasm_test_function_pointer_signature((void*)0); // expected-error {{used type 'void *' where function pointer is required}}
(void)__builtin_wasm_test_function_pointer_signature((int)0); // expected-error {{used type 'int' where function pointer is required}}
(void)__builtin_wasm_test_function_pointer_signature((F3)0); // expected-error {{not supported for function pointers with a reference type parameter}}
(void)__builtin_wasm_test_function_pointer_signature((F4)0); // expected-error {{not supported for function pointers with a reference type return value}}

// // Test valid usage
int res = __builtin_wasm_test_function_pointer_signature((F1)0);
res = __builtin_wasm_test_function_pointer_signature((F2)0);

// Test return type
_Static_assert(EXPR_HAS_TYPE(__builtin_wasm_test_function_pointer_signature((F1)0), int), "");
}
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsWebAssembly.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def int_wasm_ref_is_null_exn :
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_exnref_ty], [IntrNoMem],
"llvm.wasm.ref.is_null.exn">;

def int_wasm_ref_test_func
: DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_ptr_ty, llvm_vararg_ty],
[IntrNoMem], "llvm.wasm.ref.test.func">;

//===----------------------------------------------------------------------===//
// Table intrinsics
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,8 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser {
if (parseFunctionTableOperand(&FunctionTable))
return true;
ExpectFuncType = true;
} else if (Name == "ref.test") {
ExpectFuncType = true;
}

// Returns true if the next tokens are a catch clause
Expand Down
114 changes: 114 additions & 0 deletions llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "WebAssemblySubtarget.h"
#include "WebAssemblyTargetMachine.h"
#include "WebAssemblyUtilities.h"
#include "llvm/BinaryFormat/Wasm.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
Expand Down Expand Up @@ -505,6 +506,51 @@ MVT WebAssemblyTargetLowering::getScalarShiftAmountTy(const DataLayout & /*DL*/,
return Result;
}

static MachineBasicBlock *LowerRefTestFuncRef(MachineInstr &MI, DebugLoc DL,
MachineBasicBlock *BB,
const TargetInstrInfo &TII) {
// Lower a REF_TEST_FUNCREF_PSEUDO instruction into a REF_TEST_FUNCREF
// instruction by combining the signature info Imm operands that
// SelectionDag/InstrEmitter.cpp makes into one CImm operand. Put this into
// the type index placeholder for REF_TEST_FUNCREF
Register ResultReg = MI.getOperand(0).getReg();
Register FuncRefReg = MI.getOperand(1).getReg();

auto NParams = MI.getNumOperands() - 3;
auto Sig = APInt(NParams * 64, 0);

{
uint64_t V = MI.getOperand(2).getImm();
Sig |= int64_t(V);
}

for (unsigned I = 3; I < MI.getNumOperands(); I++) {
const MachineOperand &MO = MI.getOperand(I);
if (!MO.isImm()) {
// I'm not really sure what these are or where they come from but it seems
// to be okay to ignore them
continue;
}
uint16_t V = MO.getImm();
Sig <<= 64;
Sig |= int64_t(V);
}

ConstantInt *TypeInfo =
ConstantInt::get(BB->getParent()->getFunction().getContext(), Sig);

// Put the type info first in the placeholder for the type index, then the
// actual funcref arg
BuildMI(*BB, MI, DL, TII.get(WebAssembly::REF_TEST_FUNCREF), ResultReg)
.addCImm(TypeInfo)
.addReg(FuncRefReg);

// Remove the original instruction
MI.eraseFromParent();

return BB;
}

// Lower an fp-to-int conversion operator from the LLVM opcode, which has an
// undefined result on invalid/overflow, to the WebAssembly opcode, which
// traps on invalid/overflow.
Expand Down Expand Up @@ -866,6 +912,8 @@ MachineBasicBlock *WebAssemblyTargetLowering::EmitInstrWithCustomInserter(
switch (MI.getOpcode()) {
default:
llvm_unreachable("Unexpected instr type to insert");
case WebAssembly::REF_TEST_FUNCREF_PSEUDO:
return LowerRefTestFuncRef(MI, DL, BB, TII);
case WebAssembly::FP_TO_SINT_I32_F32:
return LowerFPToInt(MI, DL, BB, TII, false, false, false,
WebAssembly::I32_TRUNC_S_F32);
Expand Down Expand Up @@ -2260,6 +2308,72 @@ SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
DAG.getTargetExternalSymbol(TlsBase, PtrVT)),
0);
}
case Intrinsic::wasm_ref_test_func: {
// First emit the TABLE_GET instruction to convert function pointer ==>
// funcref
MachineFunction &MF = DAG.getMachineFunction();
auto PtrVT = getPointerTy(MF.getDataLayout());
MCSymbol *Table =
WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(), Subtarget);
SDValue TableSym = DAG.getMCSymbol(Table, PtrVT);
SDValue FuncRef =
SDValue(DAG.getMachineNode(WebAssembly::TABLE_GET_FUNCREF, DL,
MVT::funcref, TableSym, Op.getOperand(1)),
0);

SmallVector<SDValue, 4> Ops;
Ops.push_back(FuncRef);

// We want to encode the type information into an APInt which we'll put
// in a CImm. However, in SelectionDag/InstrEmitter.cpp there is no code
// path that emits a CImm. So we need a custom inserter to put it in.

// We'll put each type argument in a separate TargetConstant which gets
// lowered to a MachineInstruction Imm. We combine these into a CImm in our
// custom inserter because it creates a problem downstream to have all these
// extra immediates.
{
SDValue Operand = Op.getOperand(2);
MVT VT = Operand.getValueType().getSimpleVT();
WebAssembly::BlockType V;
if (VT == MVT::Untyped) {
V = WebAssembly::BlockType::Void;
} else if (VT == MVT::i32) {
V = WebAssembly::BlockType::I32;
} else if (VT == MVT::i64) {
V = WebAssembly::BlockType::I64;
} else if (VT == MVT::f32) {
V = WebAssembly::BlockType::F32;
} else if (VT == MVT::f64) {
V = WebAssembly::BlockType::F64;
} else {
llvm_unreachable("Unhandled type!");
}
Ops.push_back(DAG.getTargetConstant((int64_t)V, DL, MVT::i64));
}

for (unsigned i = 3; i < Op.getNumOperands(); ++i) {
SDValue Operand = Op.getOperand(i);
MVT VT = Operand.getValueType().getSimpleVT();
wasm::ValType V;
if (VT == MVT::i32) {
V = wasm::ValType::I32;
} else if (VT == MVT::i64) {
V = wasm::ValType::I64;
} else if (VT == MVT::f32) {
V = wasm::ValType::F32;
} else if (VT == MVT::f64) {
V = wasm::ValType::F64;
} else {
llvm_unreachable("Unhandled type!");
}
Ops.push_back(DAG.getTargetConstant((int64_t)V, DL, MVT::i64));
}

return SDValue(DAG.getMachineNode(WebAssembly::REF_TEST_FUNCREF_PSEUDO, DL,
MVT::i32, Ops),
0);
}
}
}

Expand Down
Loading
Loading