diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index dbcdb4a3af87..7677e61ea614 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -465,31 +465,31 @@ llvm::Value* CodeGenCPU::CreateCallExtern(Type ret_type, String global_symbol, } llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_types, false); // Check if it is available in global function table as injected function. - auto it = gv_func_map_.find(global_symbol); - if (it != gv_func_map_.end()) { - if (it->second == nullptr) { - gv_func_map_[global_symbol] = InitContextPtr(ftype->getPointerTo(), "__" + global_symbol); - it = gv_func_map_.find(global_symbol); - } -#if TVM_LLVM_VERSION >= 90 - auto ext_callee = llvm::FunctionCallee(ftype, GetContextPtr(it->second)); -#else - auto ext_callee = GetContextPtr(it->second); -#endif - return builder_->CreateCall(ext_callee, arg_values); - } else { - llvm::Function* f = module_->getFunction(MakeStringRef(global_symbol)); - if (f == nullptr) { - f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, - MakeStringRef(global_symbol), module_.get()); + + auto callee = [&]() -> llvm::Value* { + if (auto it = gv_func_map_.find(global_symbol); it != gv_func_map_.end()) { + if (it->second == nullptr) { + it->second = InitContextPtr(ftype->getPointerTo(), "__" + global_symbol); + } + return GetContextPtr(it->second); + } else if (llvm::Function* f = module_->getFunction(MakeStringRef(global_symbol))) { + return f; + } else { + return llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, + MakeStringRef(global_symbol), module_.get()); } + }(); + + if (callee->getType() != ftype->getPointerTo()) { + callee = builder_->CreatePointerCast(callee, ftype->getPointerTo()); + } + #if TVM_LLVM_VERSION >= 90 - auto ext_callee = llvm::FunctionCallee(f); + auto ext_callee = llvm::FunctionCallee(ftype, callee); #else - auto ext_callee = f; + auto ext_callee = f; #endif - return builder_->CreateCall(ext_callee, arg_values); - } + return builder_->CreateCall(ext_callee, arg_values); } llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string name) { diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 01e25d536118..eb53e9b6dc87 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -548,10 +548,11 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const { if (auto* ptr = type.as()) { return DTypeToLLVMType(ptr->dtype); } else if (auto* ptr = type.as()) { - // LLVM IR doesn't allow void*, so we need to recognize this - // pattern explicitly. + // LLVM IR doesn't allow void*, nor do we require custom datatypes + // to have LLVM equivalents, so we need to recognize these + // patterns explicitly. if (auto* primtype = ptr->element_type.as()) { - if (primtype->dtype.is_void()) { + if (primtype->dtype.is_void() || primtype->dtype.code() >= DataType::kCustomBegin) { return t_void_p_; } } @@ -1975,6 +1976,22 @@ void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) { } } llvm::Value* value = MakeValue(op->value); + + // TIR has type-annotations on variables, but not on each PrimExpr. + // Therefore, to have the correct LLVM type for pointers, we may + // need to introduce a pointer-cast, even though pointer-to-pointer + // casts are not expressible with the `tir::CastNode`. + if (v->dtype.is_handle() && v->type_annotation.defined()) { + CHECK(op->value->dtype.is_handle()) + << "Variable " << op->var << " is a pointer with type " << op->value + << ", but is being bound to expression with type " << op->value->dtype; + auto* llvm_type = GetLLVMType(v->type_annotation); + if (llvm_type != value->getType()) { + value->setName((v->name_hint + "_void_ptr").c_str()); + value = builder_->CreatePointerCast(value, llvm_type); + } + } + value->setName(v->name_hint.c_str()); var_map_[v] = value; analyzer_->Bind(op->var, op->value);