-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use create_memref_view_from_dlpack
within Tripy
#72
Conversation
create_memref_view_from_dlpack
within Tripy
@@ -166,11 +166,11 @@ def test_weights_loading_from_torch(self, kind): | |||
|
|||
tripy_linear = tp.Linear(2, 3) | |||
if kind == "gpu": | |||
tripy_linear.weight = tp.Parameter(tp.Tensor(torch_linear.weight.to("cuda"), device=tp.device(kind))) | |||
tripy_linear.bias = tp.Parameter(tp.Tensor(torch_linear.bias.to("cuda"), device=tp.device(kind))) | |||
tripy_linear.weight = tp.Parameter(tp.Tensor(torch_linear.weight.detach().to("cuda"), device=tp.device(kind))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens when you remove detach()
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Torch tensors will not allow invocation of the __dlpack__
method if they have a grad
tensor attached.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we instead handle this internally? This could lead to bad UX.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you suggesting we add a check within the Array._memref
function to see if the data
is a torch tensor and then manually calling detach()
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, that would be better.
This function greatly simplifies Tripy's Array implementation. We want to be able to handle memref creation from all types that implement the `__dlpack__()` interface rather than the limited set we currently support. This function should allow us to achieve this. The corresponding Tripy changes are #72.
I'm closing this PR because #95 covers all the changes. Thanks. |
Relies on #71. Simplifies our Array class by using
create_memref_view_from_dlpack
.