Use create_memref_view_from_dlpack within Tripy #72

Closed
markkraay wants to merge 3 commits from dev-mkraay-use-create-memref-view

markkraay
Collaborator

Relies on #71. Simplifies our Array class by using create_memref_view_from_dlpack.

@markkraay markkraay added enhancement New feature or request tripy Pull request for the tripy project labels Aug 8, 2024
@markkraay markkraay self-assigned this Aug 8, 2024
@github-actions github-actions bot added the mlir-tensorrt Pull request for the mlir-tensorrt project label Aug 8, 2024
@markkraay markkraay changed the title Use create memref view within Tripy Use create_memref_view_from_dlpack within Tripy Aug 8, 2024
@pranavm-nvidia pranavm-nvidia removed the mlir-tensorrt Pull request for the mlir-tensorrt project label Aug 8, 2024
@@ -166,11 +166,11 @@ def test_weights_loading_from_torch(self, kind):

         tripy_linear = tp.Linear(2, 3)
         if kind == "gpu":
-            tripy_linear.weight = tp.Parameter(tp.Tensor(torch_linear.weight.to("cuda"), device=tp.device(kind)))
-            tripy_linear.bias = tp.Parameter(tp.Tensor(torch_linear.bias.to("cuda"), device=tp.device(kind)))
+            tripy_linear.weight = tp.Parameter(tp.Tensor(torch_linear.weight.detach().to("cuda"), device=tp.device(kind)))
Collaborator

What happens when you remove detach()?

Collaborator Author

PyTorch does not allow calling `__dlpack__` on a tensor that requires gradients, so parameters such as `torch_linear.weight` have to be detached first.
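
To make the behavior described above concrete, here is a minimal sketch. The helper `try_dlpack_export` is a hypothetical name introduced for illustration and is not part of Tripy or PyTorch. The demonstration only runs when PyTorch is installed, and it catches a broad `Exception` because the exact exception type PyTorch raises may differ between versions.

```python
import importlib.util


def try_dlpack_export(tensor):
    """Call tensor.__dlpack__() and return (ok, error_message)."""
    try:
        tensor.__dlpack__()
    except Exception as exc:  # the exception type may vary across PyTorch versions
        return False, str(exc)
    return True, ""


# Run the demonstration only when PyTorch is available.
if importlib.util.find_spec("torch") is not None:
    import torch

    # Parameters of nn.Linear are created with requires_grad=True.
    weight = torch.nn.Linear(2, 3).weight
    print("requires_grad:", weight.requires_grad)
    print("export as-is:", try_dlpack_export(weight))
    # detach() returns a tensor that shares storage but does not track gradients.
    print("export detached:", try_dlpack_export(weight.detach()))
```

With PyTorch installed, the first export should report a failure and the detached export should succeed, which is why the test in this diff adds `detach()`.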

Collaborator

Can we handle this internally instead? Requiring users to call detach() themselves could lead to bad UX.

Collaborator Author

Are you suggesting we add a check within the Array._memref function to see if the data is a torch tensor and then manually call detach()?

Collaborator

Yeah, that would be better.
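
One way the internal check discussed in this thread might look is sketched below. This is an assumption for illustration, not the code that landed in Tripy: `detach_if_torch_tensor` is a hypothetical helper, and it identifies torch tensors by the module name of their type so that `torch` never has to be imported.

```python
def detach_if_torch_tensor(data):
    """Return a gradient-free tensor if `data` is a torch tensor that requires grad.

    Any other input is returned unchanged.
    """
    module = type(data).__module__ or ""
    # torch.Tensor lives in "torch"; nn.Parameter lives in "torch.nn.parameter".
    is_torch = module == "torch" or module.startswith("torch.")
    if is_torch and getattr(data, "requires_grad", False):
        # detach() shares storage with the original tensor, so no copy is made.
        return data.detach()
    return data
```

A function such as `Array._memref` could pass its input through a helper like this before requesting the DLPack capsule, so callers would not need to write `detach()` themselves.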

markkraay added a commit that referenced this pull request Aug 16, 2024
This function greatly simplifies Tripy's Array implementation. We want
to be able to handle memref creation from all types that implement the
`__dlpack__()` interface rather than the limited set we currently
support. This function should allow us to achieve this. The
corresponding Tripy changes are #72.
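
As a rough illustration of what "all types that implement the `__dlpack__()` interface" can mean in practice, the sketch below checks for the two methods of the DLPack Python protocol using only the standard library. `SupportsDLPack` and `supports_dlpack` are hypothetical names, and the sketch does not show how `create_memref_view_from_dlpack` itself is called.

```python
from typing import Any, Protocol, Tuple, runtime_checkable


@runtime_checkable
class SupportsDLPack(Protocol):
    """Structural type for objects that can be exported through DLPack."""

    def __dlpack__(self, *args: Any, **kwargs: Any) -> Any: ...

    def __dlpack_device__(self) -> Tuple[int, int]: ...


def supports_dlpack(obj: Any) -> bool:
    """Return True if `obj` exposes both DLPack protocol methods."""
    # isinstance() on a runtime-checkable Protocol only checks that the
    # methods exist; it does not validate their signatures or behavior.
    return isinstance(obj, SupportsDLPack)
```

Checking the protocol rather than a list of concrete types would let tensors from any framework that implements DLPack be accepted without special cases.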
@yizhuoz004
Collaborator

I'm closing this PR because #95 covers all the changes. Thanks.

@yizhuoz004 yizhuoz004 closed this Aug 23, 2024
@yizhuoz004 yizhuoz004 deleted the dev-mkraay-use-create-memref-view branch September 9, 2024 18:24
Labels
enhancement New feature or request tripy Pull request for the tripy project
4 participants