Skip to content

Commit

Permalink
Replaces convert_inputs_to_tensors decorator with a better decorator
Browse files Browse the repository at this point in the history
- Replaces `convert_inputs_to_tensors` decorator with `convert_to_tensors`,
     which will convert any `TensorLike` arguments by default and sync data types
     according to any specified data type constraints. This means that information
     does not need to be duplicated in the decorator like it used to be.

- Disallows unsafe downcasting in the new `convert_to_tensors` decorator.
     Previously, you could have:
     ```py
     x = tp.Tensor([1], dtype=tp.int32)
     x * 3.99
     ```
     which would yield a result equal to:
     ```py
     tp.Tensor([3])
     ```
     due to truncation when casting.

- Fixes several type annotations for binary elementwise ops.
  • Loading branch information
pranavm-nvidia committed Oct 25, 2024
1 parent f3fe6bd commit 7c6997d
Show file tree
Hide file tree
Showing 17 changed files with 336 additions and 564 deletions.
2 changes: 1 addition & 1 deletion tripy/docs/post0_developer_guides/how-to-add-new-ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ import tripy.frontend.utils as frontend_utils
@export.public_api(document_under="tensor_operations")

# The `convert_shape_inputs` decorator converts the specified function arguments into `tripy.Shape`s,
# which would allow for using Python numbers and sequences. The `convert_inputs_to_tensors` decorator more generally converts
# which would allow for using Python numbers and sequences. The `convert_to_tensors` decorator more generally converts
# function arguments into Tripy tensors and is also commonly used in the codebase.
@frontend_utils.convert_shape_inputs(["shape"])
def theta(shape: Tuple[int], dim: int = 0, dtype: datatype.dtype = datatype.float32) -> "tripy.Tensor":
Expand Down
2 changes: 2 additions & 0 deletions tripy/tests/constraints/object_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import tripy as tp
from tripy.common import datatype
from tripy.types import TensorLike


def tensor_builder(init, dtype, namespace):
Expand Down Expand Up @@ -64,6 +65,7 @@ def default_builder(init, dtype, namespace):
find_func = {
"tripy.Tensor": tensor_builder,
"tripy.types.TensorLike": tensor_builder,
TensorLike: tensor_builder,
"tripy.Shape": tensor_builder,
"tripy.dtype": dtype_builder,
datatype.dtype: dtype_builder,
Expand Down
18 changes: 13 additions & 5 deletions tripy/tests/constraints/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_all_public_apis_verified(api):
exception = True
positive_case = False

ids = [f"{dtype_name}={dtype}" for dtype_name, dtype in namespace.items()]
ids = [f"{dtype_name}-{dtype}" for dtype_name, dtype in namespace.items()]

DTYPE_CONSTRAINT_CASES.append(
pytest.param(
Expand Down Expand Up @@ -160,11 +160,18 @@ def _run_dtype_constraints_subtest(test_data):
args = [kwargs["self"]]
del kwargs["self"]

def cast_to_bool(arg0, arg1):
if arg1.dtype == tp.bool:
return bool(arg0)
return arg0

SPECIAL_FUNCS = {
"__radd__": (lambda self, other: self + other),
"__rsub__": (lambda self, other: self - other),
"__rpow__": (lambda self, other: self**other),
"__rmul__": (lambda self, other: self * other),
"__add__": (lambda self, other: self + cast_to_bool(other, self)),
"__mul__": (lambda self, other: self * cast_to_bool(other, self)),
"__radd__": (lambda self, other: cast_to_bool(self, other) + other),
"__rsub__": (lambda self, other: cast_to_bool(self, other) - other),
"__rpow__": (lambda self, other: cast_to_bool(self, other) ** other),
"__rmul__": (lambda self, other: cast_to_bool(self, other) * other),
"__rtruediv__": (lambda self, other: self / other),
"shape": (lambda self: self.shape),
}
Expand Down Expand Up @@ -200,6 +207,7 @@ def _run_dtype_constraints_subtest(test_data):

all_locals = locals()
exec(dedent(code), globals(), all_locals)

ret_val = all_locals["ret_val"]

# If output does not have dtype skip .eval().
Expand Down
2 changes: 1 addition & 1 deletion tripy/tests/frontend/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def test_quantize_rejected(self, values):
def test_binary_elementwise_broadcast_rejected(self, values):
with raises(
tp.TripyException,
match=r"\_\_mul\_\_ expects tensor arguments to have matching class types, but got mixed `tp\.Tensor` and `tp\.Shape` arguments\.",
match=r"Error processing shape inputs in operator BinaryElementwise Further information: Binary elementwise operators do not accept combinations of Shape and Tensor arguments.",
):
tp.Shape(values).multiply(tp.Tensor([values, values]))

Expand Down
Loading

0 comments on commit 7c6997d

Please sign in to comment.