Fix backward
Xia-Weiwen committed Apr 26, 2024
1 parent 93e04b5 commit e1b60d3
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 3 additions & 0 deletions bitsandbytes/autograd/_functions.py
@@ -94,6 +94,9 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -
     :param tile_indices: reverse transformation indices, from get_inverse_transform_indices
     :return: contiguous row-major tensor
     """
+    # CPU has no change on layout
+    if permuted_tensor.device.type == "cpu":
+        return permuted_tensor
     (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
     assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
     tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
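The early return added above is the core of the fix: undo_layout reverses a tiled layout permutation to recover a contiguous row-major tensor, but on CPU the tensor is already row-major, so there is nothing to undo. Below is a minimal sketch, not part of the commit, of the behavior the new branch guarantees; it assumes bitsandbytes is importable and that undo_layout keeps the signature shown in the hunk, and the identity tile_indices is purely illustrative.

import torch
from bitsandbytes.autograd._functions import undo_layout

# Any CPU tensor: there is no device-specific tiled layout to reverse here.
permuted = torch.randn(64, 64, device="cpu")
# Hypothetical identity tiling, only so the call is well-formed.
tile_indices = torch.arange(64 * 64).reshape(64, 64)

restored = undo_layout(permuted, tile_indices)
# With this commit, the CPU tensor is passed through unchanged instead of
# being reshuffled by indices that only describe GPU-specific layouts.
assert restored is permuted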
2 changes: 1 addition & 1 deletion bitsandbytes/functional.py
@@ -1946,7 +1946,7 @@ def __init__(self, rows, cols, nnz, rowidx, colidx, values):
         assert rowidx.dtype == torch.int32
         assert colidx.dtype == torch.int32
         if values.device == torch.device("cpu"):
-            assert values.dtype in [torch.bfloat16, torch.float]
+            assert values.dtype in [torch.bfloat16, torch.half, torch.float]
         else:
             assert values.dtype == torch.float16
         assert values.numel() == nnz
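This second change widens the dtype check for sparse COO values on CPU to also accept fp16 (torch.half), consistent with the commit's backward-pass fix. Below is a minimal sketch, not part of the commit, of a construction that the old assertion rejected and the new one accepts; the class name COOSparseTensor is an assumption here, since the hunk only shows its __init__.

import torch
import bitsandbytes.functional as F

rows, cols, nnz = 4, 8, 2
rowidx = torch.tensor([0, 2], dtype=torch.int32)
colidx = torch.tensor([1, 5], dtype=torch.int32)
# fp16 values on CPU: previously only bfloat16/float32 passed the assert.
values = torch.randn(nnz).half()

coo = F.COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)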
