@@ -127,7 +127,7 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
127
127
llvm::StringMap<FuncOp> cudaKernelMap;
128
128
129
129
void buildCUDAModuleCtor ();
130
- void buildCUDAModuleDtor ();
130
+ std::optional<FuncOp> buildCUDAModuleDtor ();
131
131
std::optional<FuncOp> buildCUDARegisterGlobals ();
132
132
133
133
void buildCUDARegisterGlobalFunctions (cir::CIRBaseBuilderTy &builder,
@@ -1153,6 +1153,23 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
1153
1153
builder.createCallOp (loc, endFunc, gpuBinaryHandle);
1154
1154
}
1155
1155
1156
+ // Create destructor and register it with atexit() the way NVCC does it. Doing
1157
+ // it during regular destructor phase worked in CUDA before 9.2 but results in
1158
+ // double-free in 9.2.
1159
+ if (auto dtor = buildCUDAModuleDtor ()) {
1160
+ // extern "C" int atexit(void (*f)(void));
1161
+ cir::CIRBaseBuilderTy globalBuilder (getContext ());
1162
+ globalBuilder.setInsertionPointToStart (theModule.getBody ());
1163
+ FuncOp atexit = buildRuntimeFunction (
1164
+ globalBuilder, " atexit" , loc,
1165
+ FuncType::get (PointerType::get (dtor->getFunctionType ()), intTy));
1166
+
1167
+ mlir::Value dtorFunc = builder.create <GetGlobalOp>(
1168
+ loc, PointerType::get (dtor->getFunctionType ()),
1169
+ mlir::FlatSymbolRefAttr::get (dtor->getSymNameAttr ()));
1170
+ builder.createCallOp (loc, atexit , dtorFunc);
1171
+ }
1172
+
1156
1173
builder.create <cir::ReturnOp>(loc);
1157
1174
}
1158
1175
@@ -1256,6 +1273,51 @@ void LoweringPreparePass::buildCUDARegisterGlobalFunctions(
1256
1273
}
1257
1274
}
1258
1275
1276
+ std::optional<FuncOp> LoweringPreparePass::buildCUDAModuleDtor () {
1277
+ if (!theModule->getAttr (CIRDialect::getCUDABinaryHandleAttrName ()))
1278
+ return {};
1279
+
1280
+ std::string prefix = getCUDAPrefix (astCtx);
1281
+
1282
+ auto voidTy = VoidType::get (&getContext ());
1283
+ auto voidPtrPtrTy = PointerType::get (PointerType::get (voidTy));
1284
+
1285
+ auto loc = theModule.getLoc ();
1286
+
1287
+ cir::CIRBaseBuilderTy builder (getContext ());
1288
+ builder.setInsertionPointToStart (theModule.getBody ());
1289
+
1290
+ // void __cudaUnregisterFatBinary(void ** handle);
1291
+ std::string unregisterFuncName =
1292
+ addUnderscoredPrefix (prefix, " UnregisterFatBinary" );
1293
+ FuncOp unregisterFunc = buildRuntimeFunction (
1294
+ builder, unregisterFuncName, loc, FuncType::get ({voidPtrPtrTy}, voidTy));
1295
+
1296
+ // void __cuda_module_dtor();
1297
+ // Despite the name, OG doesn't treat it as a destructor, so it shouldn't be
1298
+ // put into globalDtorList. If it were a real dtor, then it would cause double
1299
+ // free above CUDA 9.2. The way to use it is to manually call atexit() at end
1300
+ // of module ctor.
1301
+ std::string dtorName = addUnderscoredPrefix (prefix, " _module_dtor" );
1302
+ FuncOp dtor =
1303
+ buildRuntimeFunction (builder, dtorName, loc, FuncType::get ({}, voidTy),
1304
+ GlobalLinkageKind::InternalLinkage);
1305
+
1306
+ builder.setInsertionPointToStart (dtor.addEntryBlock ());
1307
+
1308
+ // For dtor, we only need to call:
1309
+ // __cudaUnregisterFatBinary(__cuda_gpubin_handle);
1310
+
1311
+ std::string gpubinName = addUnderscoredPrefix (prefix, " _gpubin_handle" );
1312
+ auto gpubinGlobal = cast<GlobalOp>(theModule.lookupSymbol (gpubinName));
1313
+ mlir::Value gpubinAddress = builder.createGetGlobal (gpubinGlobal);
1314
+ mlir::Value gpubin = builder.createLoad (loc, gpubinAddress);
1315
+ builder.createCallOp (loc, unregisterFunc, gpubin);
1316
+ builder.create <ReturnOp>(loc);
1317
+
1318
+ return dtor;
1319
+ }
1320
+
1259
1321
void LoweringPreparePass::lowerDynamicCastOp (DynamicCastOp op) {
1260
1322
CIRBaseBuilderTy builder (getContext ());
1261
1323
builder.setInsertionPointAfter (op);
@@ -1537,9 +1599,6 @@ void LoweringPreparePass::runOnOperation() {
1537
1599
datalayout.emplace (theModule);
1538
1600
}
1539
1601
1540
- auto typeSizeInfo = cast<TypeSizeInfoAttr>(
1541
- theModule->getAttr (CIRDialect::getTypeSizeInfoAttrName ()));
1542
-
1543
1602
llvm::SmallVector<Operation *> opsToTransform;
1544
1603
1545
1604
op->walk ([&](Operation *op) {
0 commit comments