Skip to content

Commit fa5b07c

Browse files
authored
[CIR][CUDA] Generate CUDA destructor (#1470)
This is Part 3 of registration function generation. It generates `__cuda_module_dtor`. The function cannot be placed in the global dtors list: treating it as a normal destructor results in a double free in recent CUDA versions (see the comments in OG, the original Clang CodeGen). Instead, it is passed as a callback to `atexit`, which is called at the end of `__cuda_module_ctor`.
1 parent 6c645a1 commit fa5b07c

File tree

2 files changed

+86
-9
lines changed

2 files changed

+86
-9
lines changed

clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
127127
llvm::StringMap<FuncOp> cudaKernelMap;
128128

129129
void buildCUDAModuleCtor();
130-
void buildCUDAModuleDtor();
130+
std::optional<FuncOp> buildCUDAModuleDtor();
131131
std::optional<FuncOp> buildCUDARegisterGlobals();
132132

133133
void buildCUDARegisterGlobalFunctions(cir::CIRBaseBuilderTy &builder,
@@ -1153,6 +1153,23 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
11531153
builder.createCallOp(loc, endFunc, gpuBinaryHandle);
11541154
}
11551155

1156+
// Create destructor and register it with atexit() the way NVCC does it. Doing
1157+
// it during regular destructor phase worked in CUDA before 9.2 but results in
1158+
// double-free in 9.2.
1159+
if (auto dtor = buildCUDAModuleDtor()) {
1160+
// extern "C" int atexit(void (*f)(void));
1161+
cir::CIRBaseBuilderTy globalBuilder(getContext());
1162+
globalBuilder.setInsertionPointToStart(theModule.getBody());
1163+
FuncOp atexit = buildRuntimeFunction(
1164+
globalBuilder, "atexit", loc,
1165+
FuncType::get(PointerType::get(dtor->getFunctionType()), intTy));
1166+
1167+
mlir::Value dtorFunc = builder.create<GetGlobalOp>(
1168+
loc, PointerType::get(dtor->getFunctionType()),
1169+
mlir::FlatSymbolRefAttr::get(dtor->getSymNameAttr()));
1170+
builder.createCallOp(loc, atexit, dtorFunc);
1171+
}
1172+
11561173
builder.create<cir::ReturnOp>(loc);
11571174
}
11581175

@@ -1256,6 +1273,51 @@ void LoweringPreparePass::buildCUDARegisterGlobalFunctions(
12561273
}
12571274
}
12581275

1276+
// Builds the CUDA module "destructor" function (named with the CUDA prefix
// followed by "_module_dtor") at the start of the module body.
//
// Returns an empty optional when the module carries no CUDA binary handle
// attribute; otherwise returns the newly created internal-linkage function.
// The generated body loads the GPU binary handle global and passes it to the
// runtime's UnregisterFatBinary function, then returns.
std::optional<FuncOp> LoweringPreparePass::buildCUDAModuleDtor() {
1277+
if (!theModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName()))
1278+
return {};
1279+
1280+
std::string prefix = getCUDAPrefix(astCtx);
1281+
1282+
auto voidTy = VoidType::get(&getContext());
1283+
auto voidPtrPtrTy = PointerType::get(PointerType::get(voidTy));
1284+
1285+
auto loc = theModule.getLoc();
1286+
1287+
cir::CIRBaseBuilderTy builder(getContext());
1288+
builder.setInsertionPointToStart(theModule.getBody());
1289+
1290+
// Declare the runtime entry point:
// void __cudaUnregisterFatBinary(void ** handle);
std::string unregisterFuncName =
1292+
addUnderscoredPrefix(prefix, "UnregisterFatBinary");
1293+
FuncOp unregisterFunc = buildRuntimeFunction(
1294+
builder, unregisterFuncName, loc, FuncType::get({voidPtrPtrTy}, voidTy));
1295+
1296+
// void __cuda_module_dtor();
1297+
// Despite the name, OG (the original Clang CodeGen) doesn't treat it as a
1298+
// destructor, so it shouldn't be put into globalDtorList. If it were a real
1299+
// dtor, it would cause a double free with CUDA 9.2 and newer. Instead, the
1300+
// module ctor registers it manually by calling atexit() at its end.
1301+
std::string dtorName = addUnderscoredPrefix(prefix, "_module_dtor");
1302+
FuncOp dtor =
1303+
buildRuntimeFunction(builder, dtorName, loc, FuncType::get({}, voidTy),
1304+
GlobalLinkageKind::InternalLinkage);
1305+
1306+
// From here on, emit operations into the dtor's freshly added entry block.
builder.setInsertionPointToStart(dtor.addEntryBlock());
1307+
1308+
// The dtor body only needs to call:
1309+
// __cudaUnregisterFatBinary(__cuda_gpubin_handle);
1310+
1311+
std::string gpubinName = addUnderscoredPrefix(prefix, "_gpubin_handle");
1312+
// cast<> (not dyn_cast<>) is used, so the handle global is expected to exist
// already; presumably the module ctor builder created it -- TODO confirm.
auto gpubinGlobal = cast<GlobalOp>(theModule.lookupSymbol(gpubinName));
1313+
mlir::Value gpubinAddress = builder.createGetGlobal(gpubinGlobal);
1314+
mlir::Value gpubin = builder.createLoad(loc, gpubinAddress);
1315+
builder.createCallOp(loc, unregisterFunc, gpubin);
1316+
builder.create<ReturnOp>(loc);
1317+
1318+
return dtor;
1319+
}
1320+
12591321
void LoweringPreparePass::lowerDynamicCastOp(DynamicCastOp op) {
12601322
CIRBaseBuilderTy builder(getContext());
12611323
builder.setInsertionPointAfter(op);
@@ -1537,9 +1599,6 @@ void LoweringPreparePass::runOnOperation() {
15371599
datalayout.emplace(theModule);
15381600
}
15391601

1540-
auto typeSizeInfo = cast<TypeSizeInfoAttr>(
1541-
theModule->getAttr(CIRDialect::getTypeSizeInfoAttrName()));
1542-
15431602
llvm::SmallVector<Operation *> opsToTransform;
15441603

15451604
op->walk([&](Operation *op) {

clang/test/CIR/CodeGen/CUDA/registration.cu

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@
1818
// CIR-HOST: cir.global_ctors = [#cir.global_ctor<"__cuda_module_ctor", {{[0-9]+}}>]
1919
// CIR-HOST: }
2020

21+
// Module destructor goes here.
22+
// This is not a real destructor, as explained in LoweringPrepare.
23+
24+
// CIR-HOST: cir.func internal private @__cuda_module_dtor() {
25+
// CIR-HOST: %[[#HandleGlobal:]] = cir.get_global @__cuda_gpubin_handle
26+
// CIR-HOST: %[[#Handle:]] = cir.load %[[#HandleGlobal]]
27+
// CIR-HOST: cir.call @__cudaUnregisterFatBinary(%[[#Handle]])
28+
// CIR-HOST: }
29+
2130
// CIR-HOST: cir.global "private" constant cir_private @".str_Z2fnv" =
2231
// CIR-HOST-SAME: #cir.const_array<"_Z2fnv", trailing_zeros>
2332

@@ -33,6 +42,12 @@
3342
// LLVM-HOST: }
3443
// LLVM-HOST: @llvm.global_ctors = {{.*}}ptr @__cuda_module_ctor
3544

45+
// LLVM-HOST: define internal void @__cuda_module_dtor() {
46+
// LLVM-HOST: %[[#LLVMHandleVar:]] = load ptr, ptr @__cuda_gpubin_handle, align 8
47+
// LLVM-HOST: call void @__cudaUnregisterFatBinary(ptr %[[#LLVMHandleVar]])
48+
// LLVM-HOST: ret void
49+
// LLVM-HOST: }
50+
3651
__global__ void fn() {}
3752

3853
// CIR-HOST: cir.func internal private @__cuda_register_globals(%[[FatbinHandle:[a-zA-Z0-9]+]]{{.*}}) {
@@ -83,12 +98,15 @@ __global__ void fn() {}
8398
// CIR-HOST: %[[#FatbinGlobal:]] = cir.get_global @__cuda_gpubin_handle
8499
// CIR-HOST: cir.store %[[#Fatbin]], %[[#FatbinGlobal]]
85100
// CIR-HOST: cir.call @__cuda_register_globals
86-
// CIR-HOTS: cir.call @__cudaRegisterFatBinaryEnd
101+
// CIR-HOST: cir.call @__cudaRegisterFatBinaryEnd
102+
// CIR-HOST: %[[#ModuleDtor:]] = cir.get_global @__cuda_module_dtor
103+
// CIR-HOST: cir.call @atexit(%[[#ModuleDtor]])
87104
// CIR-HOST: }
88105

89106
// LLVM-HOST: define internal void @__cuda_module_ctor() {
90-
// LLVM-HOST: %[[#LLVMFatbin:]] = call ptr @__cudaRegisterFatBinary(ptr @__cuda_fatbin_wrapper)
91-
// LLVM-HOST: store ptr %[[#LLVMFatbin]], ptr @__cuda_gpubin_handle
92-
// LLVM-HOST: call void @__cuda_register_globals
93-
// LLVM-HOST: call void @__cudaRegisterFatBinaryEnd
107+
// LLVM-HOST: %[[#LLVMFatbin:]] = call ptr @__cudaRegisterFatBinary(ptr @__cuda_fatbin_wrapper)
108+
// LLVM-HOST: store ptr %[[#LLVMFatbin]], ptr @__cuda_gpubin_handle
109+
// LLVM-HOST: call void @__cuda_register_globals
110+
// LLVM-HOST: call void @__cudaRegisterFatBinaryEnd
111+
// LLVM-HOST: call i32 @atexit(ptr @__cuda_module_dtor)
94112
// LLVM-HOST: }

0 commit comments

Comments
 (0)