Skip to content

Commit f8eceb4

Browse files
authored
[MLIR] [Python] align python ir printing with mlir-print-ir-after-all (#107522)
When using the `enable_ir_printing` API from Python, it invokes IR printing with default args, printing the IR before each pass and printing IR after pass only if there have been changes. This PR attempts to align the `enable_ir_printing` API with the documentation
1 parent 8280651 commit f8eceb4

File tree

5 files changed

+66
-11
lines changed

5 files changed

+66
-11
lines changed

mlir/include/mlir-c/Pass.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,11 @@ mlirPassManagerGetAsOpPassManager(MlirPassManager passManager);
7474
MLIR_CAPI_EXPORTED MlirLogicalResult
7575
mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
7676

77-
/// Enable mlir-print-ir-after-all.
78-
MLIR_CAPI_EXPORTED void
79-
mlirPassManagerEnableIRPrinting(MlirPassManager passManager);
77+
/// Enable IR printing.
78+
MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
79+
MlirPassManager passManager, bool printBeforeAll, bool printAfterAll,
80+
bool printModuleScope, bool printAfterOnlyOnChange,
81+
bool printAfterOnlyOnFailure);
8082

8183
/// Enable / disable verify-each.
8284
MLIR_CAPI_EXPORTED void

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,17 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
7474
"Releases (leaks) the backing pass manager (testing)")
7575
.def(
7676
"enable_ir_printing",
77-
[](PyPassManager &passManager) {
78-
mlirPassManagerEnableIRPrinting(passManager.get());
77+
[](PyPassManager &passManager, bool printBeforeAll,
78+
bool printAfterAll, bool printModuleScope, bool printAfterChange,
79+
bool printAfterFailure) {
80+
mlirPassManagerEnableIRPrinting(
81+
passManager.get(), printBeforeAll, printAfterAll,
82+
printModuleScope, printAfterChange, printAfterFailure);
7983
},
80-
"Enable mlir-print-ir-after-all.")
84+
"print_before_all"_a = false, "print_after_all"_a = true,
85+
"print_module_scope"_a = false, "print_after_change"_a = false,
86+
"print_after_failure"_a = false,
87+
"Enable IR printing, default as mlir-print-ir-after-all.")
8188
.def(
8289
"enable_verifier",
8390
[](PyPassManager &passManager, bool enable) {

mlir/lib/CAPI/IR/Pass.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,21 @@ MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
4444
return wrap(unwrap(passManager)->run(unwrap(op)));
4545
}
4646

47-
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
48-
return unwrap(passManager)->enableIRPrinting();
47+
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
48+
bool printBeforeAll, bool printAfterAll,
49+
bool printModuleScope,
50+
bool printAfterOnlyOnChange,
51+
bool printAfterOnlyOnFailure) {
52+
auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
53+
return printBeforeAll;
54+
};
55+
auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
56+
return printAfterAll;
57+
};
58+
return unwrap(passManager)
59+
->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
60+
printModuleScope, printAfterOnlyOnChange,
61+
printAfterOnlyOnFailure);
4962
}
5063

5164
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {

mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@ class PassManager:
1616
def __init__(self, context: Optional[_ir.Context] = None) -> None: ...
1717
def _CAPICreate(self) -> object: ...
1818
def _testing_release(self) -> None: ...
19-
def enable_ir_printing(self) -> None: ...
19+
def enable_ir_printing(
20+
self,
21+
print_before_all: bool = False,
22+
print_after_all: bool = True,
23+
print_module_scope: bool = False,
24+
print_after_change: bool = False,
25+
print_after_failure: bool = False,
26+
) -> None: ...
2027
def enable_verifier(self, enable: bool) -> None: ...
2128
@staticmethod
2229
def parse(pipeline: str, context: Optional[_ir.Context] = None) -> PassManager: ...

mlir/test/python/pass_manager.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,14 +300,40 @@ def testPrintIrAfterAll():
300300
pm = PassManager.parse("builtin.module(canonicalize)")
301301
ctx.enable_multithreading(False)
302302
pm.enable_ir_printing()
303-
# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) ('builtin.module' operation) //----- //
303+
# CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- //
304+
# CHECK: module {
305+
# CHECK: func.func @main() {
306+
# CHECK: return
307+
# CHECK: }
308+
# CHECK: }
309+
pm.run(module)
310+
311+
312+
# CHECK-LABEL: TEST: testPrintIrBeforeAndAfterAll
313+
@run
314+
def testPrintIrBeforeAndAfterAll():
315+
with Context() as ctx:
316+
module = ModuleOp.parse(
317+
"""
318+
module {
319+
func.func @main() {
320+
%0 = arith.constant 10
321+
return
322+
}
323+
}
324+
"""
325+
)
326+
pm = PassManager.parse("builtin.module(canonicalize)")
327+
ctx.enable_multithreading(False)
328+
pm.enable_ir_printing(print_before_all=True, print_after_all=True)
329+
# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) //----- //
304330
# CHECK: module {
305331
# CHECK: func.func @main() {
306332
# CHECK: %[[C10:.*]] = arith.constant 10 : i64
307333
# CHECK: return
308334
# CHECK: }
309335
# CHECK: }
310-
# CHECK: // -----// IR Dump After Canonicalizer (canonicalize) ('builtin.module' operation) //----- //
336+
# CHECK: // -----// IR Dump After Canonicalizer (canonicalize) //----- //
311337
# CHECK: module {
312338
# CHECK: func.func @main() {
313339
# CHECK: return

0 commit comments

Comments
 (0)