Skip to content

Commit

Permalink
[Bugfix][TVMScript] Capture fails if var appears only in annotation (a…
Browse files Browse the repository at this point in the history
…pache#14849)

Consider the case below:

```python
dtype = "float32"

@T.prim_func:
def f(
  A: T.Buffer((1, ), dtype),
  B: T.Buffer((1, ), dtype),
):
  ...
```

The variable `dtype` only appears in the type annotation of the function
being parsed. In this case, the python interpreter will evaluate the
annotation first before invoking the decorator, and thus if `dtype`
doesn't appear in the function body, it will not be considered as being
captured by the function itself. As a result, `inspect` module will be
unable to supply the value of `dtype` during parsing, leading to
failure.

This PR fixes the bug by maintaining a copy of function annotations
that are already parsed. Whenever expression evaluation fails during
parsing, it falls back to using the copy that is evaluated by python
interpreter.
  • Loading branch information
junrushao authored May 14, 2023
1 parent 318f894 commit 9f0c642
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 56 deletions.
18 changes: 15 additions & 3 deletions python/tvm/script/parser/core/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
# specific language governing permissions and limitations
# under the License.
"""The entry point of TVM parser."""

import inspect
from typing import Any, Dict, Union

from ...ir_builder import IRBuilder
from . import doc
from .diagnostics import Source
from .error import ParserError
from .parser import Parser


Expand Down Expand Up @@ -53,8 +54,19 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None)
"tir": tir,
}

ann = {}
if inspect.isfunction(program):
ann = {program.__name__: program.__annotations__}
elif inspect.isclass(program):
for name, func in program.__dict__.items():
if inspect.isfunction(func):
ann[name] = func.__annotations__

source = Source(program)
parser = Parser(source)
parser = Parser(source, ann)
with IRBuilder() as builder:
parser.parse(extra_vars=extra_vars)
try:
parser.parse(extra_vars=extra_vars)
except ParserError as err:
parser.report_error(err.node, err.args[0])
return builder.get()
26 changes: 26 additions & 0 deletions python/tvm/script/parser/core/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Error classes for diagnostics."""
from . import doc


class ParserError(Exception):
"""Error class for diagnostics."""

def __init__(self, node: doc.AST, msg: str):
super().__init__(msg)
self.node = node
7 changes: 4 additions & 3 deletions python/tvm/script/parser/core/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union

from . import dispatch, doc
from .error import ParserError

if TYPE_CHECKING:
from .parser import Parser
Expand Down Expand Up @@ -69,7 +70,7 @@ class ExprEvaluator:
The value table for expression evaluation.
new_value_count : int
The count for ntermediate result added during evaluation.
The count for intermediate result added during evaluation.
"""

parser: "Parser"
Expand Down Expand Up @@ -106,7 +107,7 @@ def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any:
result = self._visit(node) # pylint: disable=protected-access
if isinstance(result, doc.Name):
if result.id not in self.value_table:
self.parser.report_error(result, f"Undefined variable: {result.id}")
raise ParserError(result, f"Undefined variable: {result.id}")
return self.value_table[result.id]
if isinstance(result, doc.Constant):
return result.value
Expand Down Expand Up @@ -164,7 +165,7 @@ def _visit(self, node: doc.AST) -> Any:
assert isinstance(node, doc.AST)
if isinstance(node, doc.Name):
if node.id not in self.value_table:
self.parser.report_error(node, f"Undefined variable: {node.id}")
raise ParserError(node, f"Undefined variable: {node.id}")
return node
if isinstance(
node,
Expand Down
8 changes: 7 additions & 1 deletion python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,17 @@ class Parser(doc.NodeVisitor):

diag: Diagnostics
dispatch_tokens: List[str]
function_annotations: Optional[Dict[str, Dict[str, Any]]]
var_table: VarTable

def __init__(self, source: Source) -> None:
def __init__(
self,
source: Source,
function_annotations: Dict[str, Dict[str, Any]],
) -> None:
self.diag = Diagnostics(source)
self.dispatch_tokens = ["default"]
self.function_annotations = function_annotations
self.var_table = VarTable()

def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
Expand Down
52 changes: 24 additions & 28 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,9 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
node : doc.FunctionDef
The doc AST function definition node.
"""
supplied_annotation = self.function_annotations
func_annotation = supplied_annotation.get(node.name, {})
self.function_annotations = None
with self.var_table.with_frame():
self.var_table.add("range", T.serial)
with T.prim_func():
Expand All @@ -348,35 +351,28 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
ret_type = PrimType(ret_type().dtype)
T.func_ret(ret_type)
with self.with_dispatch_token("tir"):
self.visit(node.args)
# TODO: handle different types of arguments:
# - vararg: arg | None
# - kwonlyargs: list[arg]
# - kw_defaults: list[expr | None]
# - kwarg: arg | None
# - defaults: list[expr]
# - posonlyargs: list[arg]
for arg in node.args.args:
if arg.annotation is None:
self.report_error(arg, "Type annotation required for function parameters.")
try:
ann = self.eval_expr(arg.annotation)
if callable(ann):
ann = ann()
except Exception: # pylint: disable=broad-except
ann = func_annotation.get(arg.arg, None)
if ann is None:
raise
param = T.arg(arg.arg, ann)
self.var_table.add(arg.arg, param)
self.visit_body(node.body)


@dispatch.register(token="tir", type_name="arguments")
def visit_arguments(self: Parser, node: doc.arguments) -> None:
"""The arguments visiting method for tir.
Parameters
----------
self : Parser
The visiting parser.
node : doc.arguments
The doc AST arguments node.
"""
# TODO: handle different types of arguments:
# - vararg: arg | None
# - kwonlyargs: list[arg]
# - kw_defaults: list[expr | None]
# - kwarg: arg | None
# - defaults: list[expr]
# - posonlyargs: list[arg]
arg: doc.arg
for arg in node.args:
if arg.annotation is None:
self.report_error(arg, "Type annotation is required for function parameters.")
param = T.arg(arg.arg, self.visit_tvm_annotation(arg.annotation))
self.var_table.add(arg.arg, param)
self.function_annotations = supplied_annotation


@dispatch.register(token="tir", type_name="tvm_annotation")
Expand Down
67 changes: 46 additions & 21 deletions tests/python/unittest/test_tvmscript_meta_programming.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,41 +19,66 @@
from tvm.script import tir as T


def matmul_generator(M: int, N: int, K: int, dtype: str):
def test_meta_programming_matmul():
def matmul_generator(M: int, N: int, K: int, dtype: str):
@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [M, K], dtype=dtype)
B = T.match_buffer(b, [N, K], dtype=dtype)
C = T.match_buffer(c, [M, N], dtype=dtype)

for i, j, k in T.grid(M, N, K):
with T.block():
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

return matmul

@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [M, K], dtype=dtype)
B = T.match_buffer(b, [N, K], dtype=dtype)
C = T.match_buffer(c, [M, N], dtype=dtype)
def matmul_128_128_128_fp16(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128], dtype="float16")
B = T.match_buffer(b, [128, 128], dtype="float16")
C = T.match_buffer(c, [128, 128], dtype="float16")

for i, j, k in T.grid(M, N, K):
for i, j, k in T.grid(128, 128, 128):
with T.block():
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

return matmul
f = matmul_generator(128, 128, 128, "float16")
tvm.ir.assert_structural_equal(f, matmul_128_128_128_fp16)


def test_meta_programming_uncaptured_var():
def generate_erf(dtype):
@T.prim_func
def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)):
for i in range(1):
with T.block("C"):
C[i] = T.erf(A[i])

@T.prim_func
def matmul_128_128_128_fp16(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128], dtype="float16")
B = T.match_buffer(b, [128, 128], dtype="float16")
C = T.match_buffer(c, [128, 128], dtype="float16")
return main

for i, j, k in T.grid(128, 128, 128):
with T.block():
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
@T.prim_func
def fp32(A: T.Buffer((1,), "float32"), C: T.Buffer((1,), "float32")):
for i in range(1):
with T.block("C"):
C[i] = T.erf(A[i])

@T.prim_func
def fp16(A: T.Buffer((1,), "float16"), C: T.Buffer((1,), "float16")):
for i in range(1):
with T.block("C"):
C[i] = T.erf(A[i])

def test_meta_programming_matmul():
f = matmul_generator(128, 128, 128, "float16")
tvm.ir.assert_structural_equal(f, matmul_128_128_128_fp16)
tvm.ir.assert_structural_equal(fp16, generate_erf("float16"))
tvm.ir.assert_structural_equal(fp32, generate_erf("float32"))


if __name__ == "__main__":
test_meta_programming_matmul()
test_meta_programming_uncaptured_var()

0 comments on commit 9f0c642

Please sign in to comment.