From c265cdae97a9449f6fae2d26db79088aec11e8cc Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Sun, 7 May 2023 06:29:46 +0800 Subject: [PATCH] [TVMScript] Add `__name__` attr for parsed PrimFunc and IRModule (#14786) This PR adds `__name__` attr to indicate the func/mod name for parsed PrimFunc and IRModule. --- python/tvm/script/parser/ir/entry.py | 4 +++- python/tvm/script/parser/tir/entry.py | 4 +++- .../python/unittest/test_tvmscript_parser_ir.py | 1 + .../python/unittest/test_tvmscript_parser_tir.py | 16 ++++++++++++++-- 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/python/tvm/script/parser/ir/entry.py b/python/tvm/script/parser/ir/entry.py index 94fc3d2e2c7e..5878a1ce55cc 100644 --- a/python/tvm/script/parser/ir/entry.py +++ b/python/tvm/script/parser/ir/entry.py @@ -40,7 +40,9 @@ def ir_module(mod: Type) -> IRModule: if not inspect.isclass(mod): raise TypeError(f"Expect a class, but got: {mod}") - return parse(mod, utils.inspect_class_capture(mod)) + m = parse(mod, utils.inspect_class_capture(mod)) + setattr(m, "__name__", mod.__name__) + return m setattr(ir_module, "dispatch_token", "ir") diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 649f817411f0..d5bff7a856d5 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -42,7 +42,9 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]: raise TypeError(f"Expect a function, but got: {func}") if utils.is_defined_in_class(inspect.stack(), func): return func - return parse(func, utils.inspect_function_capture(func)) + f = parse(func, utils.inspect_function_capture(func)) + setattr(f, "__name__", func.__name__) + return f setattr(prim_func, "dispatch_token", "tir") diff --git a/tests/python/unittest/test_tvmscript_parser_ir.py b/tests/python/unittest/test_tvmscript_parser_ir.py index d3e758fbe1a0..d33594794f0c 100644 --- a/tests/python/unittest/test_tvmscript_parser_ir.py +++ b/tests/python/unittest/test_tvmscript_parser_ir.py @@ -29,6 +29,7 @@ class BlankIRModule: pass assert isinstance(BlankIRModule, IRModule) and len(BlankIRModule.functions.items()) == 0 + assert BlankIRModule.__name__ == "BlankIRModule" if __name__ == "__main__": diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index 20be6d149808..31bf5cc10180 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -16,8 +16,6 @@ # under the License. """Unittests for tvm.script.parser.tir""" -import pytest -import inspect import tvm.testing from tvm.script.parser import tir as T from tvm import ir, tir @@ -59,5 +57,19 @@ def test_tir_ptr_proxy(): ) +def test_tir_func_name(): + @T.prim_func + def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + assert matmul.__name__ == "matmul" + + if __name__ == "__main__": tvm.testing.main()