Skip to content

Commit

Permalink
Accept shape tensors in Compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
jhalakpatel committed Aug 14, 2024
1 parent 953e9b2 commit 44cbdd6
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 28 deletions.
8 changes: 7 additions & 1 deletion tripy/tests/backend/test_compiler_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ class TestInput:
([(1, 2, 3)], (1,), (2,), (3,)),
# Only one value specified
([1], (1,), (1,), (1,)),
# one dynamic and one static dim
([(1, 2, 3), 4], (1, 4), (2, 4), (3, 4)),
# Both dim dynamic
([(1, 2, 3), (4, 5, 6)], (1, 4), (2, 5), (3, 6)),
# static shape via shape tensor
(tp.Shape([1, 4]), (1, 4), (1, 4), (1, 4)),
],
)
def test_shapes_normalized(self, shape, expected_min, expected_opt, expected_max):
Expand Down Expand Up @@ -95,7 +101,7 @@ def test_invalid_shape(self, shape, expected_error):
@pytest.fixture(scope="session")
def single_return_executable():
compiler = tp.Compiler(add)
return compiler.compile(tp.InputInfo((2, 2), dtype=tp.float32), tp.InputInfo((2, 2), dtype=tp.float32))
return compiler.compile(tp.InputInfo(tp.Shape([2, 2]), dtype=tp.float32), tp.InputInfo((2, 2), dtype=tp.float32))


@pytest.fixture(scope="session")
Expand Down
64 changes: 37 additions & 27 deletions tripy/tripy/backend/compiler_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from tripy.backend.mlir import utils as mlir_utils
from tripy.common.exception import raise_error
from tripy.common.shape_bounds import ShapeBounds
from tripy.frontend import Tensor, Trace
from tripy.frontend import Tensor, Trace, Shape
from tripy.utils import json as json_utils


Expand All @@ -40,7 +40,9 @@ class InputInfo:
"""

def __init__(
self, shape: Sequence[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]], dtype: "tripy.dtype"
self,
shape: Sequence[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int], Shape]],
dtype: "tripy.dtype",
) -> None:
"""
Args:
Expand All @@ -67,36 +69,44 @@ def __init__(
assert inp.shape_bounds.opt == (2, 4)
assert inp.shape_bounds.max == (3, 4)
"""
# TODO (#252): Allow `shape` to be a shape tensor
min_shape = []
opt_shape = []
max_shape = []
for elem in shape:
if isinstance(elem, numbers.Number):
elem = (elem,) * 3
elif isinstance(elem, Sequence):
if not all(isinstance(val, numbers.Number) for val in elem):
raise_error(
"Shape values must be numbers.",
[f"Shape: {shape} contains an element: {repr(elem)} with non-numerical value(s)"],
)
if len(elem) != 3:
if isinstance(shape, Shape):
assert shape.shape.rank == 1
nb_dims = shape.shape.data().data()[0]
for i in range(nb_dims):
d = shape[i].data().data()
assert isinstance(d, numbers.Number)
min_shape.append(d)
opt_shape.append(d)
max_shape.append(d)
else:
for elem in shape:
if isinstance(elem, numbers.Number):
elem = (elem,) * 3
elif isinstance(elem, Sequence):
if not all(isinstance(val, numbers.Number) for val in elem):
raise_error(
"Shape values must be numbers.",
[f"Shape: {shape} contains an element: {repr(elem)} with non-numerical value(s)"],
)
if len(elem) != 3:
raise_error(
"Incorrect number of shape values provided.",
[
f"Exactly 3 shape values must be provided for each dimension (min/opt/max)"
f" but got: {len(elem)} values in shape: {shape}. "
],
)
else:
raise_error(
"Incorrect number of shape values provided.",
[
f"Exactly 3 shape values must be provided for each dimension (min/opt/max)"
f" but got: {len(elem)} values in shape: {shape}. "
],
"Shape values should be either a single number or a Tuple specifying min/opt/max bounds ",
[f"Shape: {shape} contains an invalid element: {elem}"],
)
else:
raise_error(
"Shape values should be either a single number or a Tuple specifying min/opt/max bounds.",
[f"Shape: {shape} contains an invalid element: {elem}"],
)

min_shape.append(elem[0])
opt_shape.append(elem[1])
max_shape.append(elem[2])
min_shape.append(elem[0])
opt_shape.append(elem[1])
max_shape.append(elem[2])

self.shape_bounds = ShapeBounds(tuple(min_shape), tuple(opt_shape), tuple(max_shape))
self.dtype = dtype
Expand Down
3 changes: 3 additions & 0 deletions tripy/tripy/frontend/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def __init__(
else:
super().__init__(data=data, shape=shape, dtype=int32, name=name, device=device)

def __iter__(self):
raise TypeError("Iterating over shape tensors is not supported")

def as_tensor(self) -> Tensor:
"""
Return an ordinary Tripy :class:`Tensor` with the same contents as this :class:`Shape` . No copying is done.
Expand Down

0 comments on commit 44cbdd6

Please sign in to comment.