From 44cbdd62c7a40a097f1c27d9b33887c5e86c914f Mon Sep 17 00:00:00 2001 From: Jhalak Patel Date: Wed, 7 Aug 2024 17:43:56 -0700 Subject: [PATCH] Accept shape tensors in Compiler --- tripy/tests/backend/test_compiler_api.py | 8 ++- tripy/tripy/backend/compiler_api.py | 64 ++++++++++++++---------- tripy/tripy/frontend/shape.py | 3 ++ 3 files changed, 47 insertions(+), 28 deletions(-) diff --git a/tripy/tests/backend/test_compiler_api.py b/tripy/tests/backend/test_compiler_api.py index f4d11d5d6..f26091a3c 100644 --- a/tripy/tests/backend/test_compiler_api.py +++ b/tripy/tests/backend/test_compiler_api.py @@ -62,6 +62,12 @@ class TestInput: ([(1, 2, 3)], (1,), (2,), (3,)), # Only one value specified ([1], (1,), (1,), (1,)), + # one dynamic and one static dim + ([(1, 2, 3), 4], (1, 4), (2, 4), (3, 4)), + # Both dim dynamic + ([(1, 2, 3), (4, 5, 6)], (1, 4), (2, 5), (3, 6)), + # static shape via shape tensor + (tp.Shape([1, 4]), (1, 4), (1, 4), (1, 4)), ], ) def test_shapes_normalized(self, shape, expected_min, expected_opt, expected_max): @@ -95,7 +101,7 @@ def test_invalid_shape(self, shape, expected_error): @pytest.fixture(scope="session") def single_return_executable(): compiler = tp.Compiler(add) - return compiler.compile(tp.InputInfo((2, 2), dtype=tp.float32), tp.InputInfo((2, 2), dtype=tp.float32)) + return compiler.compile(tp.InputInfo(tp.Shape([2, 2]), dtype=tp.float32), tp.InputInfo((2, 2), dtype=tp.float32)) @pytest.fixture(scope="session") diff --git a/tripy/tripy/backend/compiler_api.py b/tripy/tripy/backend/compiler_api.py index 219eb9a69..d42fc91b7 100644 --- a/tripy/tripy/backend/compiler_api.py +++ b/tripy/tripy/backend/compiler_api.py @@ -29,7 +29,7 @@ from tripy.backend.mlir import utils as mlir_utils from tripy.common.exception import raise_error from tripy.common.shape_bounds import ShapeBounds -from tripy.frontend import Tensor, Trace +from tripy.frontend import Tensor, Trace, Shape from tripy.utils import json as json_utils @@ -40,7 +40,9 @@ class InputInfo: """ def __init__( - self, shape: Sequence[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]], dtype: "tripy.dtype" + self, + shape: Sequence[Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int], Shape]], + dtype: "tripy.dtype", ) -> None: """ Args: @@ -67,36 +69,44 @@ def __init__( assert inp.shape_bounds.opt == (2, 4) assert inp.shape_bounds.max == (3, 4) """ - # TODO (#252): Allow `shape` to be a shape tensor min_shape = [] opt_shape = [] max_shape = [] - for elem in shape: - if isinstance(elem, numbers.Number): - elem = (elem,) * 3 - elif isinstance(elem, Sequence): - if not all(isinstance(val, numbers.Number) for val in elem): - raise_error( - "Shape values must be numbers.", - [f"Shape: {shape} contains an element: {repr(elem)} with non-numerical value(s)"], - ) - if len(elem) != 3: + if isinstance(shape, Shape): + assert shape.shape.rank == 1 + nb_dims = shape.shape.data().data()[0] + for i in range(nb_dims): + d = shape[i].data().data() + assert isinstance(d, numbers.Number) + min_shape.append(d) + opt_shape.append(d) + max_shape.append(d) + else: + for elem in shape: + if isinstance(elem, numbers.Number): + elem = (elem,) * 3 + elif isinstance(elem, Sequence): + if not all(isinstance(val, numbers.Number) for val in elem): + raise_error( + "Shape values must be numbers.", + [f"Shape: {shape} contains an element: {repr(elem)} with non-numerical value(s)"], + ) + if len(elem) != 3: + raise_error( + "Incorrect number of shape values provided.", + [ + f"Exactly 3 shape values must be provided for each dimension (min/opt/max)" + f" but got: {len(elem)} values in shape: {shape}. " + ], + ) + else: raise_error( - "Incorrect number of shape values provided.", - [ - f"Exactly 3 shape values must be provided for each dimension (min/opt/max)" - f" but got: {len(elem)} values in shape: {shape}. " - ], + "Shape values should be either a single number or a Tuple specifying min/opt/max bounds ", + [f"Shape: {shape} contains an invalid element: {elem}"], ) - else: - raise_error( - "Shape values should be either a single number or a Tuple specifying min/opt/max bounds.", - [f"Shape: {shape} contains an invalid element: {elem}"], - ) - - min_shape.append(elem[0]) - opt_shape.append(elem[1]) - max_shape.append(elem[2]) + min_shape.append(elem[0]) + opt_shape.append(elem[1]) + max_shape.append(elem[2]) self.shape_bounds = ShapeBounds(tuple(min_shape), tuple(opt_shape), tuple(max_shape)) self.dtype = dtype diff --git a/tripy/tripy/frontend/shape.py b/tripy/tripy/frontend/shape.py index cc25cf226..91fce9be8 100644 --- a/tripy/tripy/frontend/shape.py +++ b/tripy/tripy/frontend/shape.py @@ -79,6 +79,9 @@ def __init__( else: super().__init__(data=data, shape=shape, dtype=int32, name=name, device=device) + def __iter__(self): + raise TypeError("Iterating over shape tensors is not supported") + def as_tensor(self) -> Tensor: """ Return an ordinary Tripy :class:`Tensor` with the same contents as this :class:`Shape` . No copying is done.