From e1b60d3093759bca952b9aefea4ba1140c1e2340 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 25 Apr 2024 23:51:57 -0700 Subject: [PATCH] Fix backward --- bitsandbytes/autograd/_functions.py | 3 +++ bitsandbytes/functional.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 08d2d9fa6..7d570f28b 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -94,6 +94,9 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) - :param tile_indices: reverse transformation indices, from get_inverse_transform_indices :return: contiguous row-major tensor """ + # CPU has no change on layout + if permuted_tensor.device.type == "cpu": + return permuted_tensor (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles" tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t() diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 9828add30..8fd62fd04 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1946,7 +1946,7 @@ def __init__(self, rows, cols, nnz, rowidx, colidx, values): assert rowidx.dtype == torch.int32 assert colidx.dtype == torch.int32 if values.device == torch.device("cpu"): - assert values.dtype in [torch.bfloat16, torch.float] + assert values.dtype in [torch.bfloat16, torch.half, torch.float] else: assert values.dtype == torch.float16 assert values.numel() == nnz