Accept shape tensors in Compiler #64
Conversation
# Both dim dynamic
([(1, 2, 3), (4, 5, 6)], (1, 4), (2, 5), (3, 6)),
# min/opt/max specified as shape tensor
([tp.Shape([1, 2, 3])], (1,), (2,), (3,)),
@pranavm-nvidia is this how you expected the shape profile to be provided using a shape tensor?
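For concreteness, a minimal sketch of the two parameterizations the test cases above exercise, assuming the InputInfo constructor from this PR (semantics inferred from the test parameters, so treat this as an assumption rather than the settled API):
import tripy as tp
# Per-dim (min, opt, max) bounds given as plain tuples:
info_tuples = tp.InputInfo([(1, 2, 3), (4, 5, 6)], dtype=tp.float32)
# The same (min, opt, max) for a single dim given as a shape tensor,
# which is the form this PR adds:
info_shape = tp.InputInfo([tp.Shape([1, 2, 3])], dtype=tp.float32)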
tripy/tripy/backend/compiler_api.py (Outdated)
],
)
elif isinstance(elem, Shape):
    elem = elem.data().data()
This seems suboptimal. We should be able to slice shape tensors to populate the shape bounds directly.
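For illustration, a sketch of the suggested alternative, assuming tp.Shape supports integer indexing (as used elsewhere in this thread):
# Pull (min, opt, max) for one dimension straight out of the rank-1
# shape tensor by indexing, instead of round-tripping through elem.data().data().
min_bound, opt_bound, max_bound = elem[0], elem[1], elem[2]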
Why is elem a shape tensor? What you want is for the shape itself to be a shape tensor describing the shape of the input:
a = tp.ones((2,3))
shape_of_a = a.shape
a_info = tp.InputInfo(shape_of_a, dtype=tp.float32)
Your current approach is doing:
shape_of_a = a.shape
a_info = tp.InputInfo([shape_of_a[0], shape_of_a[1]], dtype=tp.float32)
)
if len(elem) != 3:
if isinstance(shape, Shape):
    assert shape.shape.rank == 1
Why is this shape.shape.rank and not shape.rank?
Thanks for catching this. I think shape.rank == 1 holds by construction; what I wanted to check here is len(shape.shape) == 1. #92 should help with this check once merged.
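For reference, a sketch of the expressions under discussion, assuming (as stated above) that tp.Shape is always a rank-1 tensor of dimension sizes:
s = tp.Shape([2, 3, 4])  # describes the shape of a rank-3 input
s.rank    # 1 by construction: a shape tensor is one-dimensional
s.shape   # the shape of the shape tensor itself, i.e. [3]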
([(1, 2, 3), 4], (1, 4), (2, 4), (3, 4)),
# Both dim dynamic
([(1, 2, 3), (4, 5, 6)], (1, 4), (2, 5), (3, 6)),
# static shape via shape tensor
I am not sure this is useful, since we cannot encode a dynamic dim here unless we allow something like tp.Shape([(1, 2, 3), 4]), where (1, 2, 3) is min/opt/max.
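To illustrate the limitation, a sketch assuming this PR's API: the shape tensor form can only pin every dimension, while the existing list form already mixes per-dim bounds and static sizes (as in the first test case above):
static_only = tp.InputInfo(tp.Shape([1, 4]), dtype=tp.float32)  # every dim fixed
mixed = tp.InputInfo([(1, 2, 3), 4], dtype=tp.float32)          # dim 0 dynamic, dim 1 static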
@parthchadha If it does not make sense, I can close the PR. @pranavm-nvidia mentioned that he filed the issue to add support for shape tensor inputs with value bounds, but that the TODO (#252) was added in InputInfo incorrectly.
@@ -95,7 +101,7 @@ def test_invalid_shape(self, shape, expected_error):
 @pytest.fixture(scope="session")
 def single_return_executable():
     compiler = tp.Compiler(add)
-    return compiler.compile(tp.InputInfo((2, 2), dtype=tp.float32), tp.InputInfo((2, 2), dtype=tp.float32))
+    return compiler.compile(tp.InputInfo(tp.Shape([2, 2]), dtype=tp.float32), tp.InputInfo((2, 2), dtype=tp.float32))
The idea was to support shape tensor inputs at runtime. I don't think we need to change anything about the compiler; InputInfo can already be used to express the range of values the shape tensor can take.
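A sketch of that suggestion using the PR-era API shown in the test above, assuming tuple elements are (min, opt, max) bounds:
compiler = tp.Compiler(add)
# Per-dim bounds at compile time already constrain whatever shape values
# appear at runtime; no shape-tensor-aware compile() path is needed.
executable = compiler.compile(
    tp.InputInfo([(1, 2, 3), (1, 2, 3)], dtype=tp.float32),
    tp.InputInfo((2, 2), dtype=tp.float32),
)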
[f"Shape: {shape} contains an element: {repr(elem)} with non-numerical value(s)"], | ||
) | ||
if len(elem) != 3: | ||
if isinstance(shape, Shape): |
I don't think there's any value in allowing shape tensors in compile. InputInfo already expresses everything we want, AFAICT.
Closing this PR based on the above discussion.
No description provided.