Skip to content

Commit

Permalink
Added pre and post device call to transform.
Browse files Browse the repository at this point in the history
  • Loading branch information
TimDettmers committed Aug 4, 2022
1 parent 320eacb commit 6101a8f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,6 +1214,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
is_on_gpu([A, out])
prev_device = pre_call(A.device)
if to_order == 'col32':
if transpose:
lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
Expand All @@ -1236,8 +1237,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
else:
raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')


post_call(prev_device)


return out, new_state
Expand Down

0 comments on commit 6101a8f

Please sign in to comment.