Skip to content

Commit

Permalink
[TVMScript] Add __name__ attr for parsed PrimFunc and IRModule (apa…
Browse files Browse the repository at this point in the history
…che#14786)

This PR adds `__name__` attr to indicate the func/mod name for parsed PrimFunc and IRModule.
  • Loading branch information
Hzfengsy authored May 6, 2023
1 parent 571eff9 commit c265cda
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 4 deletions.
4 changes: 3 additions & 1 deletion python/tvm/script/parser/ir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def ir_module(mod: Type) -> IRModule:
if not inspect.isclass(mod):
raise TypeError(f"Expect a class, but got: {mod}")

return parse(mod, utils.inspect_class_capture(mod))
m = parse(mod, utils.inspect_class_capture(mod))
setattr(m, "__name__", mod.__name__)
return m


setattr(ir_module, "dispatch_token", "ir")
4 changes: 3 additions & 1 deletion python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
raise TypeError(f"Expect a function, but got: {func}")
if utils.is_defined_in_class(inspect.stack(), func):
return func
return parse(func, utils.inspect_function_capture(func))
f = parse(func, utils.inspect_function_capture(func))
setattr(f, "__name__", func.__name__)
return f


setattr(prim_func, "dispatch_token", "tir")
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_tvmscript_parser_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class BlankIRModule:
pass

assert isinstance(BlankIRModule, IRModule) and len(BlankIRModule.functions.items()) == 0
assert BlankIRModule.__name__ == "BlankIRModule"


if __name__ == "__main__":
Expand Down
16 changes: 14 additions & 2 deletions tests/python/unittest/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
# under the License.
"""Unittests for tvm.script.parser.tir"""

import pytest
import inspect
import tvm.testing
from tvm.script.parser import tir as T
from tvm import ir, tir
Expand Down Expand Up @@ -59,5 +57,19 @@ def test_tir_ptr_proxy():
)


def test_tir_func_name():
@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128])
B = T.match_buffer(b, [128, 128])
C = T.match_buffer(c, [128, 128])
for i, j, k in T.grid(128, 128, 128):
with T.block("update"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

assert matmul.__name__ == "matmul"


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit c265cda

Please sign in to comment.