Skip to content

Commit a83a7e7

Browse files
author
git apple-llvm automerger
committed
Merge commit '0c50971460cb' from llvm.org/main into next
2 parents 89c3f42 + 0c50971 commit a83a7e7

File tree

2 files changed

+79
-5
lines changed

2 files changed

+79
-5
lines changed

mlir/lib/Transforms/SymbolDCE.cpp

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Transforms/Passes.h"
1515

1616
#include "mlir/IR/SymbolTable.h"
17+
#include "llvm/Support/Debug.h"
1718

1819
namespace mlir {
1920
#define GEN_PASS_DEF_SYMBOLDCE
@@ -22,6 +23,8 @@ namespace mlir {
2223

2324
using namespace mlir;
2425

26+
#define DEBUG_TYPE "symbol-dce"
27+
2528
namespace {
2629
struct SymbolDCE : public impl::SymbolDCEBase<SymbolDCE> {
2730
void runOnOperation() override;
@@ -84,6 +87,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
8487
SymbolTableCollection &symbolTable,
8588
bool symbolTableIsHidden,
8689
DenseSet<Operation *> &liveSymbols) {
90+
LLVM_DEBUG(llvm::dbgs() << "computeLiveness: " << symbolTableOp->getName()
91+
<< "\n");
8792
// A worklist of live operations to propagate uses from.
8893
SmallVector<Operation *, 16> worklist;
8994

@@ -105,36 +110,70 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
105110
}
106111

107112
// Process the set of symbols that were known to be live, adding new symbols
108-
// that are referenced within.
113+
// that are referenced within. For operations that are not symbol tables, it
114+
// considers the liveness with respect to the op itself rather than scope of
115+
// nested symbol tables by enqueuing all the top level operations for
116+
// consideration.
109117
while (!worklist.empty()) {
110118
Operation *op = worklist.pop_back_val();
119+
LLVM_DEBUG(llvm::dbgs() << "processing: " << op->getName() << "\n");
111120

112121
// If this is a symbol table, recursively compute its liveness.
113122
if (op->hasTrait<OpTrait::SymbolTable>()) {
114123
// The internal symbol table is hidden if the parent is, if its not a
115124
// symbol, or if it is a private symbol.
116125
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
117126
bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
127+
LLVM_DEBUG(llvm::dbgs() << "\tsymbol table: " << op->getName()
128+
<< " is hidden: " << symIsHidden << "\n");
118129
if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
119130
return failure();
131+
} else {
132+
LLVM_DEBUG(llvm::dbgs()
133+
<< "\tnon-symbol table: " << op->getName() << "\n");
134+
// If the op is not a symbol table, then, unless op itself is dead which
135+
// would be handled by DCE, we need to check all the regions and blocks
136+
// within the op to find the uses (e.g., consider visibility within op as
137+
// if top level rather than relying on pure symbol table visibility). This
138+
// is more conservative than SymbolTable::walkSymbolTables in the case
139+
// where there is again SymbolTable information to take advantage of.
140+
for (auto &region : op->getRegions())
141+
for (auto &block : region.getBlocks())
142+
for (Operation &op : block)
143+
if (op.getNumRegions())
144+
worklist.push_back(&op);
120145
}
121146

147+
// Get the first parent symbol table op. Note: due to enqueueing of
148+
// top-level ops, we may not have a symbol table parent here, but if we do
149+
// not, then we also don't have a symbol.
150+
Operation *parentOp = op->getParentOp();
151+
if (!parentOp->hasTrait<OpTrait::SymbolTable>())
152+
continue;
153+
122154
// Collect the uses held by this operation.
123155
std::optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(op);
124156
if (!uses) {
125157
return op->emitError()
126-
<< "operation contains potentially unknown symbol table, "
127-
"meaning that we can't reliable compute symbol uses";
158+
<< "operation contains potentially unknown symbol table, meaning "
159+
<< "that we can't reliable compute symbol uses";
128160
}
129161

130162
SmallVector<Operation *, 4> resolvedSymbols;
163+
LLVM_DEBUG(llvm::dbgs() << "uses of " << op->getName() << "\n");
131164
for (const SymbolTable::SymbolUse &use : *uses) {
165+
LLVM_DEBUG(llvm::dbgs() << "\tuse: " << use.getUser() << "\n");
132166
// Lookup the symbols referenced by this use.
133167
resolvedSymbols.clear();
134-
if (failed(symbolTable.lookupSymbolIn(
135-
op->getParentOp(), use.getSymbolRef(), resolvedSymbols)))
168+
if (failed(symbolTable.lookupSymbolIn(parentOp, use.getSymbolRef(),
169+
resolvedSymbols)))
136170
// Ignore references to unknown symbols.
137171
continue;
172+
LLVM_DEBUG({
173+
llvm::dbgs() << "\t\tresolved symbols: ";
174+
llvm::interleaveComma(resolvedSymbols, llvm::dbgs());
175+
llvm::dbgs() << "\n";
176+
});
138177

139178
// Mark each of the resolved symbols as live.
140179
for (Operation *resolvedSymbol : resolvedSymbols)

mlir/test/Transforms/test-symbol-dce.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,38 @@ module {
9898
// CHECK: "live.user"() {uses = [@unknown_symbol]} : () -> ()
9999
"live.user"() {uses = [@unknown_symbol]} : () -> ()
100100
}
101+
102+
// -----
103+
104+
// Check that we don't DCE nested symbols if they are nested inside region
105+
// without SymbolTable.
106+
107+
// CHECK-LABEL: module attributes {test.nested_nosymboltable_region}
108+
module attributes { test.nested_nosymboltable_region } {
109+
"test.one_region_op"() ({
110+
"test.symbol_scope"() ({
111+
// CHECK: func nested @nfunction
112+
func.func nested @nfunction() {
113+
return
114+
}
115+
func.call @nfunction() : () -> ()
116+
"test.finish"() : () -> ()
117+
}) : () -> ()
118+
"test.finish"() : () -> ()
119+
}) : () -> ()
120+
}
121+
122+
// -----
123+
124+
// CHECK-LABEL: module attributes {test.nested_nosymboltable_region_notcalled}
125+
// CHECK-NOT: @nested
126+
// CHECK: @main
127+
module attributes { test.nested_nosymboltable_region_notcalled } {
128+
"test.one_region_op"() ({
129+
module {
130+
func.func nested @nested() { return }
131+
func.func @main() { return }
132+
}
133+
"test.finish"() : () -> ()
134+
}) : () -> ()
135+
}

0 commit comments

Comments
 (0)