Skip to content

Commit 7caf12d

Browse files
j2kunjoker-eph
andauthored
[mlir][core] Add an MLIR "pattern catalog" generator (#146228)
This PR adds a feature that attaches a listener to all RewritePatterns that logs information about the modified operations. When the MLIR test suite is run, these debug outputs can be filtered and combined into an index linking operations to the patterns that insert, modify, or replace them. This index is intended to be used to create a website that allows one to look up patterns from an operation name. The debug logs emitted can be viewed with --debug-only=generate-pattern-catalog, and the lit config is modified to do this when the env var MLIR_GENERATE_PATTERN_CATALOG is set. Example usage: ``` mkdir build && cd build cmake -G Ninja ../llvm \ -DLLVM_ENABLE_PROJECTS="mlir" \ -DLLVM_TARGETS_TO_BUILD="host" \ -DCMAKE_BUILD_TYPE=DEBUG ninja -j 24 check-mlir MLIR_GENERATE_PATTERN_CATALOG=1 bin/llvm-lit -j 24 -v -a tools/mlir/test | grep 'pattern-logging-listener' | sed 's/^# | [pattern-logging-listener] //g' | sort | uniq > pattern_catalog.txt ``` Sample pattern catalog output (that fits in a gist): https://gist.github.com/j2kun/02d1ab8d31c10d71027724984c89905a --------- Co-authored-by: Jeremy Kun <[email protected]> Co-authored-by: Mehdi Amini <[email protected]>
1 parent a7f595e commit 7caf12d

File tree

6 files changed

+112
-2
lines changed

6 files changed

+112
-2
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,25 @@ class RewriterBase : public OpBuilder {
475475
RewriterBase::Listener *rewriteListener;
476476
};
477477

478+
/// A listener that logs notification events to llvm::dbgs() before
479+
/// forwarding to the base listener.
480+
struct PatternLoggingListener : public RewriterBase::ForwardingListener {
481+
PatternLoggingListener(OpBuilder::Listener *listener, StringRef patternName)
482+
: RewriterBase::ForwardingListener(listener), patternName(patternName) {
483+
}
484+
485+
void notifyOperationInserted(Operation *op, InsertPoint previous) override;
486+
void notifyOperationModified(Operation *op) override;
487+
void notifyOperationReplaced(Operation *op, Operation *newOp) override;
488+
void notifyOperationReplaced(Operation *op,
489+
ValueRange replacement) override;
490+
void notifyOperationErased(Operation *op) override;
491+
void notifyPatternBegin(const Pattern &pattern, Operation *op) override;
492+
493+
private:
494+
StringRef patternName;
495+
};
496+
478497
/// Move the blocks that belong to "region" before the given position in
479498
/// another region "parent". The two regions must be different. The caller
480499
/// is responsible for creating or updating the operation transferring flow

mlir/lib/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ add_mlir_library(MLIRIR
2929
ODSSupport.cpp
3030
Operation.cpp
3131
OperationSupport.cpp
32+
PatternLoggingListener.cpp
3233
PatternMatch.cpp
3334
Region.cpp
3435
RegionKindInterface.cpp
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#include "mlir/IR/PatternMatch.h"
2+
#include "llvm/Support/Debug.h"
3+
4+
#define DEBUG_TYPE "pattern-logging-listener"
5+
#define DBGS() (llvm::dbgs() << "[" << DEBUG_TYPE << "] ")
6+
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
7+
8+
using namespace mlir;
9+
10+
void RewriterBase::PatternLoggingListener::notifyOperationInserted(
11+
Operation *op, InsertPoint previous) {
12+
LDBG(patternName << " | notifyOperationInserted"
13+
<< " | " << op->getName());
14+
ForwardingListener::notifyOperationInserted(op, previous);
15+
}
16+
17+
void RewriterBase::PatternLoggingListener::notifyOperationModified(
18+
Operation *op) {
19+
LDBG(patternName << " | notifyOperationModified"
20+
<< " | " << op->getName());
21+
ForwardingListener::notifyOperationModified(op);
22+
}
23+
24+
void RewriterBase::PatternLoggingListener::notifyOperationReplaced(
25+
Operation *op, Operation *newOp) {
26+
LDBG(patternName << " | notifyOperationReplaced (with op)"
27+
<< " | " << op->getName() << " | " << newOp->getName());
28+
ForwardingListener::notifyOperationReplaced(op, newOp);
29+
}
30+
31+
void RewriterBase::PatternLoggingListener::notifyOperationReplaced(
32+
Operation *op, ValueRange replacement) {
33+
LDBG(patternName << " | notifyOperationReplaced (with values)"
34+
<< " | " << op->getName());
35+
ForwardingListener::notifyOperationReplaced(op, replacement);
36+
}
37+
38+
void RewriterBase::PatternLoggingListener::notifyOperationErased(
39+
Operation *op) {
40+
LDBG(patternName << " | notifyOperationErased"
41+
<< " | " << op->getName());
42+
ForwardingListener::notifyOperationErased(op);
43+
}
44+
45+
void RewriterBase::PatternLoggingListener::notifyPatternBegin(
46+
const Pattern &pattern, Operation *op) {
47+
LDBG(patternName << " | notifyPatternBegin"
48+
<< " | " << op->getName());
49+
ForwardingListener::notifyPatternBegin(pattern, op);
50+
}

mlir/lib/Rewrite/PatternApplicator.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
#include "ByteCode.h"
1616
#include "llvm/Support/Debug.h"
1717

18+
#ifndef NDEBUG
19+
#include "llvm/ADT/ScopeExit.h"
20+
#endif
21+
1822
#define DEBUG_TYPE "pattern-application"
1923

2024
using namespace mlir;
@@ -206,11 +210,19 @@ LogicalResult PatternApplicator::matchAndRewrite(
206210
} else {
207211
LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
208212
<< bestPattern->getDebugName() << "\"\n");
209-
210213
const auto *pattern =
211214
static_cast<const RewritePattern *>(bestPattern);
212-
result = pattern->matchAndRewrite(op, rewriter);
213215

216+
#ifndef NDEBUG
217+
OpBuilder::Listener *oldListener = rewriter.getListener();
218+
auto loggingListener =
219+
std::make_unique<RewriterBase::PatternLoggingListener>(
220+
oldListener, pattern->getDebugName());
221+
rewriter.setListener(loggingListener.get());
222+
auto resetListenerCallback = llvm::make_scope_exit(
223+
[&] { rewriter.setListener(oldListener); });
224+
#endif
225+
result = pattern->matchAndRewrite(op, rewriter);
214226
LLVM_DEBUG(llvm::dbgs()
215227
<< "\"" << bestPattern->getDebugName() << "\" result "
216228
<< succeeded(result) << "\n");
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: mlir-opt %s --test-walk-pattern-rewrite-driver \
2+
// RUN: --allow-unregistered-dialect --debug-only=pattern-logging-listener 2>&1 | FileCheck %s
3+
4+
// Check that when replacing an op with a new op, we get appropriate
5+
// pattern-logging lines. The regex is because the anonymous namespace is
6+
// printed differently on different platforms.
7+
8+
// CHECK: [pattern-logging-listener] {{.anonymous.namespace.}}::ReplaceWithNewOp | notifyOperationInserted | test.new_op
9+
// CHECK: [pattern-logging-listener] {{.anonymous.namespace.}}::ReplaceWithNewOp | notifyOperationReplaced (with values) | test.replace_with_new_op
10+
// CHECK: [pattern-logging-listener] {{.anonymous.namespace.}}::ReplaceWithNewOp | notifyOperationModified | arith.addi
11+
// CHECK: [pattern-logging-listener] {{.anonymous.namespace.}}::ReplaceWithNewOp | notifyOperationModified | arith.addi
12+
// CHECK: [pattern-logging-listener] {{.anonymous.namespace.}}::ReplaceWithNewOp | notifyOperationErased | test.replace_with_new_op
13+
func.func @replace_with_new_op() -> i32 {
14+
%a = "test.replace_with_new_op"() : () -> (i32)
15+
%res = arith.addi %a, %a : i32
16+
return %res : i32
17+
}

mlir/test/lit.cfg.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,17 @@ def find_real_python_interpreter():
301301
ToolSubst("mlir-opt", "mlir-opt --verify-roundtrip", unresolved="fatal"),
302302
]
303303
)
304+
elif "MLIR_GENERATE_PATTERN_CATALOG" in os.environ:
305+
tools.extend(
306+
[
307+
ToolSubst(
308+
"mlir-opt",
309+
"mlir-opt --debug-only=pattern-logging-listener --mlir-disable-threading",
310+
unresolved="fatal",
311+
),
312+
ToolSubst("FileCheck", "FileCheck --dump-input=always", unresolved="fatal"),
313+
]
314+
)
304315
else:
305316
tools.extend(["mlir-opt"])
306317

0 commit comments

Comments
 (0)