Adds dtype constraint enforcement, various bug fixes and improvements
- Adds a new `test_all_public_apis_verified` test that ensures that all
     public APIs that accept or return tensor types declare their data type
     constraints.

- Updates `test_dtype_constraints` to work with methods.

- Fixes a bug in `conf.py` that was preventing the doc generation from correctly
     checking for dtype constraint information on methods.

- Adds support for checking data types of sequence arguments.

- Renames `@constraints.dtype_info` to `@constraints.dtypes` and shortens various
     argument names (see the usage sketch after this list).

- Updates `convert_inputs_to_tensors` to require that the names of the arguments
     to convert be specified explicitly instead of converting all arguments
     by default (also illustrated in the sketch after this list).
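
A minimal sketch of the renamed `@constraints.dtypes` decorator, mirroring the `sequence_func` test added in the diff below: `constraints` maps argument names to dtype variables and `variables` lists the data types each variable may take. It also exercises the new sequence-argument checking.

```python
from typing import List

import tripy as tp
from tripy import constraints


# Every tensor in `tensors` must share the dtype bound to the variable "T1",
# which here may only be float32.
@constraints.dtypes(constraints={"tensors": "T1"}, variables={"T1": ["float32"]})
def sequence_func(tensors: List[tp.Tensor]):
    return


# OK: both tensors are float32, so the constraint is satisfied.
sequence_func([tp.ones((2, 2), dtype=tp.float32), tp.ones((2, 2), dtype=tp.float32)])
```

And a minimal sketch of the updated `convert_inputs_to_tensors` behavior; the import path is an assumption made for illustration, and the call sites in `tripy/tests/frontend/test_utils.py` below are the authoritative reference. Only the arguments named in the list are converted; anything else is passed through untouched.

```python
import tripy as tp

# Import path assumed for illustration; see tripy/tests/frontend/test_utils.py
# in this diff for the real call sites.
from tripy.frontend.utils import convert_inputs_to_tensors


# Only "a" is listed, so only "a" is converted to a tp.Tensor.
@convert_inputs_to_tensors(["a"])
def func(a, b):
    return a, b


a, b = func(1.0, 2.0)
assert isinstance(a, tp.Tensor)      # converted
assert not isinstance(b, tp.Tensor)  # left as a plain float
```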
pranavm-nvidia committed Oct 23, 2024
1 parent 1b02e32 commit 8923f16
Showing 48 changed files with 506 additions and 476 deletions.
3 changes: 0 additions & 3 deletions tripy/docs/README.md
@@ -18,9 +18,6 @@ To view the documentation, you can open `build/docs/index.html` in a browser.
The `export.public_api()` decorator allows you to specify metadata for documentation
generation, such as where in the documentation hierarchy the API should be documented.

The `constraints.dtype_info()` decorator verifies the data types a function claims to support and generates
corresponding documentation. For more information, see [this guide](../tests/spec_verification/README.md).

The `generate_rsts.py` script uses this information to automatically generate a directory
structure and populate it with `.rst` files.

24 changes: 12 additions & 12 deletions tripy/docs/conf.py
@@ -157,7 +157,7 @@
def process_docstring(app, what, name, obj, options, lines):
doc = "\n".join(lines).strip()
blocks = helper.consolidate_code_blocks(doc)
unqual_name = name.split(".")[-1]
name = name.lstrip("tripy.")

# Check signature for functions/methods and class constructors.
if what in {"function", "method"} or (what == "class" and name in seen_classes):
@@ -177,7 +177,7 @@ def process_docstring(app, what, name, obj, options, lines):
pname = "*" + pname

# Type annotations are optional for the `self` parameter unless the API has to be type-verified.
if pname != "self" or unqual_name in TYPE_VERIFICATION:
if pname != "self" or name in TYPE_VERIFICATION:
assert (
pname in documented_args
), f"Missing documentation for parameter: '{pname}' in: '{obj}'. Please ensure you've included this in the `Args:` section. Note: Documented parameters were: {documented_args} {doc}"
@@ -204,7 +204,7 @@ def process_docstring(app, what, name, obj, options, lines):
":returns:" in doc
), f"For: {obj}, return value is not documented. Please ensure you've included a `Returns:` section"

if unqual_name in TYPE_VERIFICATION:
if name in TYPE_VERIFICATION:
add_text_index = -1
for index, block in enumerate(blocks):

@@ -215,7 +215,7 @@ def insert_block(text):
index += 1

if re.search(r".. code-block::", block):
type_dict = TYPE_VERIFICATION[unqual_name].dtypes
type_dict = TYPE_VERIFICATION[name].dtypes
insert_block("TYPE CONSTRAINTS:")
# Add the dtype constraint name and the dtypes that correlate.
for type_name, dt in type_dict.items():
@@ -234,10 +234,10 @@ def insert_block(text):
)
insert_block("\n")

if TYPE_VERIFICATION[unqual_name].dtype_exceptions:
if TYPE_VERIFICATION[name].exceptions:
# Add the dtype exceptions.
insert_block("UNSUPPORTED TYPE COMBINATIONS:")
for exception_dict in TYPE_VERIFICATION[unqual_name].dtype_exceptions:
for exception_dict in TYPE_VERIFICATION[name].exceptions:
insert_block(
" - "
+ ", ".join([f"**{key}**\ =\ :class:`{val}`" for key, val in exception_dict.items()]),
@@ -248,17 +248,17 @@
if re.search(r":param \w+: ", block):
param_name = re.match(r":param (\w+): ", block).group(1)
# Add dtype constraint to start of each parameter description.
if TYPE_VERIFICATION[unqual_name].dtype_constraints.get(param_name, None):
if TYPE_VERIFICATION[name].constraints.get(param_name, None):
add_text_index = re.search(r":param \w+: ", block).span()[1]
blocks[index] = (
f"{block[0:add_text_index]}[*dtype=*\ **{TYPE_VERIFICATION[unqual_name].dtype_constraints[param_name]}**\ ] {block[add_text_index:]}"
f"{block[0:add_text_index]}[*dtype=*\ **{TYPE_VERIFICATION[name].constraints[param_name]}**\ ] {block[add_text_index:]}"
)

if TYPE_VERIFICATION[unqual_name].return_dtype is not None and re.search(r":returns:", block):
if TYPE_VERIFICATION[name].return_dtype is not None and re.search(r":returns:", block):
add_text_index = re.search(r":returns:", block).span()[1] + 1
# Add dtype constraint to start of returns description.
blocks[index] = (
f"{block[0:add_text_index]}[*dtype=*\ **{TYPE_VERIFICATION[unqual_name].return_dtype}**\ ] {block[add_text_index:]}"
f"{block[0:add_text_index]}[*dtype=*\ **{TYPE_VERIFICATION[name].return_dtype}**\ ] {block[add_text_index:]}"
)

seen_classes.add(name)
@@ -267,8 +267,8 @@ def allow_no_example():
# `tp.Module`s include examples in their constructors, so their __call__ methods don't require examples.
is_tripy_module_call_method = False
if what == "method" and obj.__name__ == "__call__":
class_name = name.rpartition(".")[0]
# Class names will be prefixed with tripy.<...>, so we need to import it here to make eval() work.
class_name = "tripy." + name.rpartition(".")[0]
# Class names are prefixed with tripy.<...>, so we need to import it here to make eval() work.
import tripy

is_tripy_module_call_method = issubclass(eval(class_name), tp.Module)
@@ -15,11 +15,11 @@
# limitations under the License.
#

import tripy as tp
import inspect
from typing import ForwardRef, List, Optional, Union, get_args, get_origin

from typing import Union, Optional, get_origin, get_args, ForwardRef, List
import tripy as tp
from tripy.common import datatype
import inspect


def tensor_builder(init, dtype, namespace):
@@ -29,7 +29,10 @@ def tensor_builder(init, dtype, namespace):
return out
elif not isinstance(init, tp.Tensor):
return init
out = tp.cast(init, dtype=namespace[dtype])

out = init
if dtype is not None:
out = tp.cast(out, dtype=namespace[dtype])
out.eval()
return out

@@ -146,6 +149,8 @@ def default_builder(init, dtype, namespace):
"maxpool": {"input": tp.ones((1, 1, 8, 8)), "kernel_dims": [2, 2]},
"avgpool": {"input": tp.ones((1, 1, 8, 8)), "kernel_dims": [2, 2]},
"zeros": {"shape": [3, 2]},
# Methods
"Shape.as_tensor": {"self": tp.Shape([1, 2, 3])},
}


@@ -24,11 +24,44 @@
import pytest
from tests import helper
from tests.conftest import skip_if_older_than_sm89
from tests.spec_verification.object_builders import create_obj
from tests.constraints.object_builders import create_obj

import tripy as tp
from tripy import constraints
from tripy.common.datatype import DATA_TYPES
from tripy.constraints import TYPE_VERIFICATION
from tripy.export import PUBLIC_APIS

# Get all functions/methods which have tensors in the type signature
PUBLIC_API_TENSOR_FUNCTIONS = []
PUBLIC_API_TENSOR_FUNCTION_NAMES = []
for api in PUBLIC_APIS:
is_module = False
if inspect.isfunction(api.obj):
funcs = [api.obj]
elif inspect.isclass(api.obj):
if issubclass(api.obj, tp.Module):
# Skip over modules since the dtype constraint decorator doesn't work for them yet.
continue
funcs = [val for _, val in inspect.getmembers(api.obj, predicate=inspect.isfunction)]

for func in funcs:
if "Tensor" in str(inspect.signature(func)):
PUBLIC_API_TENSOR_FUNCTIONS.append(func)
name = api.qualname
if func.__name__ not in name:
name += f".{func.__name__}"
PUBLIC_API_TENSOR_FUNCTION_NAMES.append(name)


@pytest.mark.parametrize("api", PUBLIC_API_TENSOR_FUNCTIONS, ids=PUBLIC_API_TENSOR_FUNCTION_NAMES)
def test_all_public_apis_verified(api):
NON_VERIFIABLE_APIS = {"plugin", "Executable.__call__"}
if api.__qualname__ in NON_VERIFIABLE_APIS:
pytest.skip(f"Cannot do data type verification for: {NON_VERIFIABLE_APIS}")

assert api.__qualname__ in TYPE_VERIFICATION, f"Missing datatype constraints for: {api.__qualname__}"


DTYPE_CONSTRAINT_CASES = []

@@ -139,21 +172,34 @@ def _run_dtype_constraints_subtest(test_data):
if func_name in SPECIAL_FUNCS:
ret_val = SPECIAL_FUNCS[func_name](*args, **kwargs)
else:
all_locals = locals()
exec(
dedent(
# We can't call `func_obj` directly because there may be other decorators
# applied after the dtype constraints one. By importing it like this, we
# get the final version of the function.
f"""
from {func_obj.__module__} import {func_obj.__qualname__}
# We can't call `func_obj` directly because there may be other decorators
# applied after the dtype constraints one. By importing it like this, we
# get the final version of the function/class.
#
# NOTE: inspect.ismethod() will not work, possibly because of our decorators.
if "." in func_obj.__qualname__:
cls, method = func_obj.__qualname__.split(".")

# For methods, the first argument will be the instance
obj = args.pop(0)

code = f"""
from {func_obj.__module__} import {cls}
ret_val = obj.{method}(*args, **kwargs)
"""
else:
code = f"""
from {func_obj.__module__} import {func_obj.__qualname__}
if {func_name} == "shape":
ret_val = args[0].shape
else:
ret_val = {func_obj.__qualname__}(*args, **kwargs)
"""
),
globals(),
all_locals,
)
"""

all_locals = locals()
exec(dedent(code), globals(), all_locals)
ret_val = all_locals["ret_val"]

# If output does not have dtype skip .eval().
@@ -184,3 +230,17 @@ def test_dtype_constraints(test_data):
ret_val, namespace = _run_dtype_constraints_subtest(test_data)
if isinstance(ret_val, tp.Tensor):
assert ret_val.dtype == namespace[return_dtype]


@constraints.dtypes(constraints={"tensors": "T1"}, variables={"T1": ["float32"]})
def sequence_func(tensors: List[tp.Tensor]):
return


class TestDtypes:
def test_works_with_sequences(self):
sequence_func([tp.ones((2, 2), dtype=tp.float32), tp.ones((2, 2), dtype=tp.float32)])

def test_raises_on_mismatched_sequence_dtypes(self):
with helper.raises(tp.TripyException, match="Mismatched data types in sequence argument for 'sequence_func'."):
sequence_func([tp.ones((2, 2), dtype=tp.float32), tp.ones((2, 2), dtype=tp.int32)])
22 changes: 11 additions & 11 deletions tripy/tests/frontend/test_utils.py
@@ -27,57 +27,57 @@
# for magic methods. We would not want to see this outside of tests.


@convert_inputs_to_tensors()
@convert_inputs_to_tensors(["a"])
def __func_test_basic__(a):
return a


@convert_inputs_to_tensors()
@convert_inputs_to_tensors(["a", "b", "c"])
def __func_test_multi_input__(a, b, c):
return a, b, c


@convert_inputs_to_tensors(sync_arg_types=[("a", "b", "c")])
@convert_inputs_to_tensors(["a", "b", "c"], sync_arg_types=[("a", "b", "c")])
def __func_test_sync_arg_types__(a, b, c):
return a, b, c


@convert_inputs_to_tensors()
@convert_inputs_to_tensors(["args"])
def __func_test_variadic_positional_args__(*args):
return args


@convert_inputs_to_tensors()
@convert_inputs_to_tensors(["x", "args"])
def __func_test_arg_before_variadic_positional_args__(x, *args):
return (x,) + args


@convert_inputs_to_tensors()
@convert_inputs_to_tensors(["args", "y"])
def __func_test_kwarg_after_variadic_positional_args__(*args, y):
return args + (y,)


@convert_inputs_to_tensors(unpack_argument=["xs"])
@convert_inputs_to_tensors(["xs"], unpack_argument=["xs"])
def __func_test_convert_list_input__(xs):
return xs


@convert_inputs_to_tensors(sync_arg_types=[("xs",)], unpack_argument=["xs"])
@convert_inputs_to_tensors(["xs"], sync_arg_types=[("xs",)], unpack_argument=["xs"])
def __func_test_sync_within_list__(xs):
return xs


@convert_inputs_to_tensors(sync_arg_types=[("x", "ys")], unpack_argument=["ys"])
@convert_inputs_to_tensors(["x", "ys"], sync_arg_types=[("x", "ys")], unpack_argument=["ys"])
def __func_test_sync_single_type_to_list__(x, ys):
return x, ys


@convert_inputs_to_tensors(sync_arg_types=[("xs", "y")], unpack_argument=["xs"])
@convert_inputs_to_tensors(["xs", "y"], sync_arg_types=[("xs", "y")], unpack_argument=["xs"])
def __func_test_sync_list_type_to_single__(xs, y):
return xs, y


@convert_inputs_to_tensors(sync_arg_types=[("xs", "ys")], unpack_argument=["xs", "ys"])
@convert_inputs_to_tensors(["xs", "ys"], sync_arg_types=[("xs", "ys")], unpack_argument=["xs", "ys"])
def __func_test_sync_list_types__(xs, ys):
return xs, ys

2 changes: 0 additions & 2 deletions tripy/tests/spec_verification/.gitignore

This file was deleted.

11 changes: 0 additions & 11 deletions tripy/tests/spec_verification/README.md

This file was deleted.
