Addresses the case when the shape of the upsample tensor contains an ITensor #3841
base: main
Conversation
LGTM. I think the functionality of to_trt_shape_tensor is probably duplicated manually in a couple of places (e.g., concat, IIRC). Do you think we could unify all of this?
Yeah, that's true. cat does it for its inputs while the above is for the shape tensor. I guess we could unify this. Should
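For illustration, a minimal sketch of what a shared helper could look like if the int-promotion-plus-concat pattern were factored out of the shape-tensor path and the cat converter. The name promote_and_concat is hypothetical; ctx, target, name, and set_layer_name follow the conventions visible in the diff below, with set_layer_name's import omitted.

import numpy as np

def promote_and_concat(ctx, target, name, tensors, axis=0):
    # Hypothetical shared helper: promote Python ints to 1-element int32
    # constants, then concatenate everything into a single shape tensor.
    itensors = []
    for i, t in enumerate(tensors):
        if isinstance(t, int):
            const = ctx.net.add_constant((1,), np.array([t], dtype=np.int32))
            set_layer_name(const, target, f"{name}_static_{i}_const")
            t = const.get_output(0)
        itensors.append(t)
    concat = ctx.net.add_concatenation(itensors)
    concat.axis = axis
    set_layer_name(concat, target, f"{name}_concat")
    return concat.get_output(0)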
9403b0f to 3d3a8ee (Compare)
# promote remaining ints to TRT consts before concat
for i, t in enumerate(trt_tensors):
    if isinstance(t, int):
        const = ctx.net.add_constant((1,), np.array([t], dtype=np.int32))
        set_layer_name(const, target, f"{name}_static_{i}_const")
        trt_tensors[i] = const.get_output(0)

concat = ctx.net.add_concatenation(trt_tensors)
If trt_tensors has a mix of scalar integers and ITensors of dtype int64, would this work, given that you're casting the scalar integers to int32 explicitly?
In the case of shape tensors, the ints will always be int32, so this should work.
Coming to the cat case: the concat tensors will be either torch.Tensor or TRTTensor; they cannot be int. So I think the above should cover all the cases. Can you think of any other case?
My thought is: how are we explicitly ensuring that all trt_tensors have the same datatype before concatenating here? A dtype mismatch will error out.
This check could either be an assertion or explicit type promotion of the tensors within trt_tensors.
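For illustration, a hedged sketch of the explicit-promotion option. cast_trt_tensor and TRTTensor are the helpers already used in this PR's diffs; picking int32 as the common dtype is an assumption for the example, not necessarily what the converter should do.

common_dtype = trt.int32  # assumption: the target dtype here is a design choice
# After the int -> constant promotion above, everything in trt_tensors should
# already be an ITensor; cast any stragglers to one common dtype before concat.
for i, t in enumerate(trt_tensors):
    if isinstance(t, TRTTensor) and t.dtype != common_dtype:
        trt_tensors[i] = cast_trt_tensor(ctx, t, common_dtype, f"{name}_promote_{i}")

# Or, as an assertion instead of promotion:
assert len({t.dtype for t in trt_tensors}) == 1, "mixed dtypes before concat"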
Embedding bag looks like it is failing. Need to look into it.
1771071 to 91a2519 (Compare)
@apbose I have 2 minor questions. Rest LGTM
elif isinstance(cast_dtype, np.dtype):
    final_dtype = _enums.dtype._from(cast_dtype).to(trt.DataType)
else:
    final_dtype = cast_dtype  # already trt.DataType
Should we also handle the torch.dtype case?
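For reference, a hedged sketch of what adding a torch.dtype branch could look like; treating _enums.dtype._from as accepting torch.dtype the same way it accepts np.dtype is an assumption here, not something this PR already does.

import numpy as np
import tensorrt as trt
import torch
from torch_tensorrt import _enums

if isinstance(cast_dtype, trt.DataType):
    final_dtype = cast_dtype  # already trt.DataType
elif isinstance(cast_dtype, (torch.dtype, np.dtype)):
    # same conversion the np.dtype branch uses in the diff above
    final_dtype = _enums.dtype._from(cast_dtype).to(trt.DataType)
else:
    raise AssertionError(f"unsupported cast_dtype: {cast_dtype}")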
# optional cast
if cast_dtype and isinstance(t, TRTTensor):
    t = cast_trt_tensor(ctx, t, cast_dtype, f"{name}_cast_{i}")
Is this necessary if we are also casting at line 69 onwards?
c8a070c to 1b1dfed (Compare)
Addresses #3783