diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 208bab52284a3..6d21706570bbe 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -19,6 +19,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "NVPTX.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/ConstantFolding.h"
@@ -59,7 +60,10 @@ class NVVMReflect {
   StringMap<unsigned> ReflectMap;
   bool handleReflectFunction(Module &M, StringRef ReflectName);
   void populateReflectMap(Module &M);
-  void foldReflectCall(CallInst *Call, Constant *NewValue);
+  void replaceReflectCalls(
+      SmallVector<std::pair<CallInst *, Constant *>, 8> &ReflectReplacements,
+      const DataLayout &DL);
+  SetVector<BasicBlock *> findTransitivelyDeadBlocks(BasicBlock *DeadBB);
 
 public:
   // __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
@@ -87,11 +91,6 @@ static cl::opt<bool>
     NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden,
                        cl::desc("NVVM reflection, enabled by default"));
 
-char NVVMReflectLegacyPass::ID = 0;
-INITIALIZE_PASS(NVVMReflectLegacyPass, "nvvm-reflect",
-                "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
-                false)
-
 // Allow users to specify additional key/value pairs to reflect. These key/value
 // pairs are the last to be added to the ReflectMap, and therefore will take
 // precedence over initial values (i.e. __CUDA_FTZ from module medadata and
@@ -101,6 +100,15 @@ static cl::list<std::string> ReflectList(
     cl::desc("A key=value pair. Replace __nvvm_reflect(name) with value."),
     cl::ValueRequired);
 
+static cl::opt<bool> NVVMReflectDCE(
+    "nvvm-reflect-dce", cl::init(false), cl::Hidden,
+    cl::desc("Delete dead blocks introduced by reflect call elimination"));
+
+char NVVMReflectLegacyPass::ID = 0;
+INITIALIZE_PASS(NVVMReflectLegacyPass, "nvvm-reflect",
+                "Replace occurrences of __nvvm_reflect() calls with 0/1", false,
+                false)
+
 // Set the ReflectMap with, first, the value of __CUDA_FTZ from module metadata,
 // and then the key/value pairs from the command line.
 void NVVMReflect::populateReflectMap(Module &M) {
@@ -138,6 +146,8 @@ bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
   assert(F->getReturnType()->isIntegerTy() &&
          "_reflect's return type should be integer");
 
+  SmallVector<std::pair<CallInst *, Constant *>, 8> ReflectReplacements;
+
   const bool Changed = !F->use_empty();
   for (User *U : make_early_inc_range(F->users())) {
     // Reflect function calls look like:
@@ -178,19 +188,47 @@ bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
                       << "(" << ReflectArg << ") with value " << ReflectVal
                       << "\n");
     auto *NewValue = ConstantInt::get(Call->getType(), ReflectVal);
-    foldReflectCall(Call, NewValue);
-    Call->eraseFromParent();
+    ReflectReplacements.push_back({Call, NewValue});
   }
 
-  // Remove the __nvvm_reflect function from the module
+  replaceReflectCalls(ReflectReplacements, M.getDataLayout());
   F->eraseFromParent();
   return Changed;
 }
 
-void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
+/// Find all blocks that become dead transitively from an initial dead block.
+/// Returns the complete set including the original dead block and any blocks
+/// that lose all their predecessors due to the deletion cascade.
+SetVector<BasicBlock *>
+NVVMReflect::findTransitivelyDeadBlocks(BasicBlock *DeadBB) {
+  SmallVector<BasicBlock *, 8> Worklist({DeadBB});
+  SetVector<BasicBlock *> DeadBlocks;
+  while (!Worklist.empty()) {
+    auto *BB = Worklist.pop_back_val();
+    DeadBlocks.insert(BB);
+
+    for (BasicBlock *Succ : successors(BB)) {
+      // A successor dies once every one of its predecessors is known dead.
+      bool AllPredsDead = all_of(predecessors(Succ), [&](BasicBlock *Pred) {
+        return DeadBlocks.contains(Pred);
+      });
+      if (AllPredsDead && DeadBlocks.insert(Succ))
+        Worklist.push_back(Succ);
+    }
+  }
+  return DeadBlocks;
+}
+
+/// Replace calls to __nvvm_reflect with corresponding constant values. Then
+/// clean up through constant folding and propagation and dead block
+/// elimination, if NVVMReflectDCE is enabled.
+void NVVMReflect::replaceReflectCalls(
+    SmallVector<std::pair<CallInst *, Constant *>, 8> &ReflectReplacements,
+    const DataLayout &DL) {
   SmallVector<Instruction *, 8> Worklist;
-  // Replace an instruction with a constant and add all users of the instruction
-  // to the worklist
+  SetVector<BasicBlock *> DeadBlocks;
+
+  // Replace an instruction with a constant and add all users to the worklist
   auto ReplaceInstructionWithConst = [&](Instruction *I, Constant *C) {
     for (auto *U : I->users())
       if (auto *UI = dyn_cast<Instruction>(U))
@@ -198,18 +236,75 @@ void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
     I->replaceAllUsesWith(C);
   };
 
-  ReplaceInstructionWithConst(Call, NewValue);
+  for (auto &[Call, NewValue] : ReflectReplacements) {
+    ReplaceInstructionWithConst(Call, NewValue);
+    Call->eraseFromParent();
+  }
 
-  auto &DL = Call->getModule()->getDataLayout();
-  while (!Worklist.empty()) {
-    auto *I = Worklist.pop_back_val();
-    if (auto *C = ConstantFoldInstruction(I, DL)) {
-      ReplaceInstructionWithConst(I, C);
-      if (isInstructionTriviallyDead(I))
-        I->eraseFromParent();
-    } else if (I->isTerminator()) {
-      ConstantFoldTerminator(I->getParent());
+  // Constant fold reflect results. If NVVMReflectDCE is enabled, we will
+  // alternate between constant folding/propagation and dead block elimination.
+  // Terminator folding may create new dead blocks. When those dead blocks are
+  // deleted, their live successors may have PHIs that can be simplified, which
+  // may yield more work for folding/propagation.
+  while (true) {
+    // Iterate folding and propagating constants until the worklist is empty.
+    while (!Worklist.empty()) {
+      auto *I = Worklist.pop_back_val();
+      if (auto *C = ConstantFoldInstruction(I, DL)) {
+        ReplaceInstructionWithConst(I, C);
+        if (isInstructionTriviallyDead(I))
+          I->eraseFromParent();
+      } else if (I->isTerminator()) {
+        BasicBlock *BB = I->getParent();
+        SmallVector<BasicBlock *, 8> Succs(successors(BB));
+        // Some blocks may become dead if the terminator is folded because
+        // a conditional branch is turned into a direct branch. Add those dead
+        // blocks to the dead blocks set if NVVMReflectDCE is enabled.
+        if (ConstantFoldTerminator(BB)) {
+          for (BasicBlock *Succ : Succs) {
+            if (pred_empty(Succ) &&
+                Succ != &Succ->getParent()->getEntryBlock() && NVVMReflectDCE) {
+              SetVector<BasicBlock *> TransitivelyDead =
+                  findTransitivelyDeadBlocks(Succ);
+              DeadBlocks.insert(TransitivelyDead.begin(),
+                                TransitivelyDead.end());
+            }
+          }
+        }
+      }
     }
+    // No more constants to fold and no more dead blocks
+    // to create more work. We're done.
+    if (DeadBlocks.empty())
+      break;
+    // Live blocks that are about to lose a dead predecessor.
+    SmallSetVector<BasicBlock *, 8> LiveSuccs;
+    for (BasicBlock *DeadBB : DeadBlocks)
+      for (BasicBlock *Succ : successors(DeadBB))
+        if (!DeadBlocks.contains(Succ))
+          LiveSuccs.insert(Succ);
+    // PHI nodes of those live successors get simplified or eliminated when
+    // the dead blocks are deleted. Their users can then be folded further, so
+    // queue them. Each user is queued once, and instructions that the
+    // deletion below may destroy (anything inside a dead block, and PHIs of
+    // the affected successors) are skipped so the worklist never dangles.
+    SmallSetVector<Instruction *, 8> PHIUsers;
+    for (BasicBlock *Succ : LiveSuccs)
+      for (PHINode &PHI : Succ->phis())
+        for (User *U : PHI.users()) {
+          auto *UI = dyn_cast<Instruction>(U);
+          if (!UI || DeadBlocks.contains(UI->getParent()))
+            continue;
+          if (isa<PHINode>(UI) && LiveSuccs.contains(UI->getParent()))
+            continue;
+          PHIUsers.insert(UI);
+        }
+    // Delete all dead blocks
+    for (BasicBlock *DeadBB : DeadBlocks)
+      DeleteDeadBlock(DeadBB);
+    Worklist.append(PHIUsers.begin(), PHIUsers.end());
+
+    DeadBlocks.clear();
   }
 }
 
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll
index 19c74df303702..553b2c107d86a 100644
--- a/llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll
@@ -3,12 +3,12 @@
 
 ; RUN: cat %s > %t.noftz
 ; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 0}' >> %t.noftz
-; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
+; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' -nvvm-reflect-dce \
 ; RUN:   | FileCheck %s --check-prefix=USE_FTZ_0 --check-prefix=CHECK
 
 ; RUN: cat %s > %t.ftz
 ; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}' >> %t.ftz
-; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
+; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' -nvvm-reflect-dce \
 ; RUN:   | FileCheck %s --check-prefix=USE_FTZ_1 --check-prefix=CHECK
 
 @str = private unnamed_addr addrspace(4) constant [11 x i8] c"__CUDA_FTZ\00"
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect.ll
index 244b44fea9b83..86cdc3f489c2e 100644
--- a/llvm/test/CodeGen/NVPTX/nvvm-reflect.ll
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect.ll
@@ -3,12 +3,12 @@
 
 ; RUN: cat %s > %t.noftz
 ; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 0}' >> %t.noftz
-; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
+; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' -nvvm-reflect-dce \
 ; RUN:   | FileCheck %s --check-prefix=USE_FTZ_0 --check-prefix=CHECK
 
 ; RUN: cat %s > %t.ftz
 ; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}' >> %t.ftz
-; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
+; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' -nvvm-reflect-dce \
 ; RUN:   | FileCheck %s --check-prefix=USE_FTZ_1 --check-prefix=CHECK
 
 @str = private unnamed_addr addrspace(4) constant [11 x i8] c"__CUDA_FTZ\00"