diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 9e6c100c954d..5315c0f6755e 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. """The entry point of TVM parser.""" - +import inspect from typing import Any, Dict, Union from ...ir_builder import IRBuilder from . import doc from .diagnostics import Source +from .error import ParserError from .parser import Parser @@ -53,8 +54,19 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) "tir": tir, } + ann = {} + if inspect.isfunction(program): + ann = {program.__name__: program.__annotations__} + elif inspect.isclass(program): + for name, func in program.__dict__.items(): + if inspect.isfunction(func): + ann[name] = func.__annotations__ + source = Source(program) - parser = Parser(source) + parser = Parser(source, ann) with IRBuilder() as builder: - parser.parse(extra_vars=extra_vars) + try: + parser.parse(extra_vars=extra_vars) + except ParserError as err: + parser.report_error(err.node, err.args[0]) return builder.get() diff --git a/python/tvm/script/parser/core/error.py b/python/tvm/script/parser/core/error.py new file mode 100644 index 000000000000..6c10fe83d6a5 --- /dev/null +++ b/python/tvm/script/parser/core/error.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Error classes for diagnostics.""" +from . import doc + + +class ParserError(Exception): + """Error class for diagnostics.""" + + def __init__(self, node: doc.AST, msg: str): + super().__init__(msg) + self.node = node diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 075aedd89146..96901c522d22 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union from . import dispatch, doc +from .error import ParserError if TYPE_CHECKING: from .parser import Parser @@ -69,7 +70,7 @@ class ExprEvaluator: The value table for expression evaluation. new_value_count : int - The count for ntermediate result added during evaluation. + The count for intermediate result added during evaluation. """ parser: "Parser" @@ -106,7 +107,7 @@ def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any: result = self._visit(node) # pylint: disable=protected-access if isinstance(result, doc.Name): if result.id not in self.value_table: - self.parser.report_error(result, f"Undefined variable: {result.id}") + raise ParserError(result, f"Undefined variable: {result.id}") return self.value_table[result.id] if isinstance(result, doc.Constant): return result.value @@ -164,7 +165,7 @@ def _visit(self, node: doc.AST) -> Any: assert isinstance(node, doc.AST) if isinstance(node, doc.Name): if node.id not in self.value_table: - self.parser.report_error(node, f"Undefined variable: {node.id}") + raise ParserError(node, f"Undefined variable: {node.id}") return node if isinstance( node, diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 837b7cce5d5e..72858a202853 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -236,11 +236,17 @@ class Parser(doc.NodeVisitor): diag: Diagnostics dispatch_tokens: List[str] + function_annotations: Optional[Dict[str, Dict[str, Any]]] var_table: VarTable - def __init__(self, source: Source) -> None: + def __init__( + self, + source: Source, + function_annotations: Dict[str, Dict[str, Any]], + ) -> None: self.diag = Diagnostics(source) self.dispatch_tokens = ["default"] + self.function_annotations = function_annotations self.var_table = VarTable() def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any: diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index ea26c4740a46..0a489a8f0401 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -338,6 +338,9 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: node : doc.FunctionDef The doc AST function definition node. """ + supplied_annotation = self.function_annotations + func_annotation = supplied_annotation.get(node.name, {}) + self.function_annotations = None with self.var_table.with_frame(): self.var_table.add("range", T.serial) with T.prim_func(): @@ -348,35 +351,28 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: ret_type = PrimType(ret_type().dtype) T.func_ret(ret_type) with self.with_dispatch_token("tir"): - self.visit(node.args) + # TODO: handle different types of arguments: + # - vararg: arg | None + # - kwonlyargs: list[arg] + # - kw_defaults: list[expr | None] + # - kwarg: arg | None + # - defaults: list[expr] + # - posonlyargs: list[arg] + for arg in node.args.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation required for function parameters.") + try: + ann = self.eval_expr(arg.annotation) + if callable(ann): + ann = ann() + except Exception: # pylint: disable=broad-except + ann = func_annotation.get(arg.arg, None) + if ann is None: + raise + param = T.arg(arg.arg, ann) + self.var_table.add(arg.arg, param) self.visit_body(node.body) - - -@dispatch.register(token="tir", type_name="arguments") -def visit_arguments(self: Parser, node: doc.arguments) -> None: - """The arguments visiting method for tir. - - Parameters - ---------- - self : Parser - The visiting parser. - - node : doc.arguments - The doc AST arguments node. - """ - # TODO: handle different types of arguments: - # - vararg: arg | None - # - kwonlyargs: list[arg] - # - kw_defaults: list[expr | None] - # - kwarg: arg | None - # - defaults: list[expr] - # - posonlyargs: list[arg] - arg: doc.arg - for arg in node.args: - if arg.annotation is None: - self.report_error(arg, "Type annotation is required for function parameters.") - param = T.arg(arg.arg, self.visit_tvm_annotation(arg.annotation)) - self.var_table.add(arg.arg, param) + self.function_annotations = supplied_annotation @dispatch.register(token="tir", type_name="tvm_annotation") diff --git a/tests/python/unittest/test_tvmscript_meta_programming.py b/tests/python/unittest/test_tvmscript_meta_programming.py index 2473c0c84564..ed567b659444 100644 --- a/tests/python/unittest/test_tvmscript_meta_programming.py +++ b/tests/python/unittest/test_tvmscript_meta_programming.py @@ -19,41 +19,66 @@ from tvm.script import tir as T -def matmul_generator(M: int, N: int, K: int, dtype: str): +def test_meta_programming_matmul(): + def matmul_generator(M: int, N: int, K: int, dtype: str): + @T.prim_func + def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [M, K], dtype=dtype) + B = T.match_buffer(b, [N, K], dtype=dtype) + C = T.match_buffer(c, [M, N], dtype=dtype) + + for i, j, k in T.grid(M, N, K): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + return matmul + @T.prim_func - def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [M, K], dtype=dtype) - B = T.match_buffer(b, [N, K], dtype=dtype) - C = T.match_buffer(c, [M, N], dtype=dtype) + def matmul_128_128_128_fp16(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float16") + B = T.match_buffer(b, [128, 128], dtype="float16") + C = T.match_buffer(c, [128, 128], dtype="float16") - for i, j, k in T.grid(M, N, K): + for i, j, k in T.grid(128, 128, 128): with T.block(): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = T.float32(0) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] - return matmul + f = matmul_generator(128, 128, 128, "float16") + tvm.ir.assert_structural_equal(f, matmul_128_128_128_fp16) + +def test_meta_programming_uncaptured_var(): + def generate_erf(dtype): + @T.prim_func + def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): + for i in range(1): + with T.block("C"): + C[i] = T.erf(A[i]) -@T.prim_func -def matmul_128_128_128_fp16(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [128, 128], dtype="float16") - B = T.match_buffer(b, [128, 128], dtype="float16") - C = T.match_buffer(c, [128, 128], dtype="float16") + return main - for i, j, k in T.grid(128, 128, 128): - with T.block(): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - C[vi, vj] = T.float32(0) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + @T.prim_func + def fp32(A: T.Buffer((1,), "float32"), C: T.Buffer((1,), "float32")): + for i in range(1): + with T.block("C"): + C[i] = T.erf(A[i]) + @T.prim_func + def fp16(A: T.Buffer((1,), "float16"), C: T.Buffer((1,), "float16")): + for i in range(1): + with T.block("C"): + C[i] = T.erf(A[i]) -def test_meta_programming_matmul(): - f = matmul_generator(128, 128, 128, "float16") - tvm.ir.assert_structural_equal(f, matmul_128_128_128_fp16) + tvm.ir.assert_structural_equal(fp16, generate_erf("float16")) + tvm.ir.assert_structural_equal(fp32, generate_erf("float32")) if __name__ == "__main__": test_meta_programming_matmul() + test_meta_programming_uncaptured_var()