Skip to content

Commit

Permalink
fix 4bit tensor shape
Browse files Browse the repository at this point in the history
Signed-off-by: jiqing-feng <[email protected]>
  • Loading branch information
jiqing-feng committed Jan 20, 2025
1 parent 629af94 commit b7ca20b
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions bitsandbytes/backends/cpu_xpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def quantize_4bit_impl(
quant_type=quant_type,
)

return out.unsqueeze(0), state
return out.reshape(-1, 1), state


def dequant_8bit(A, offset, quant_state):
Expand Down Expand Up @@ -449,12 +449,8 @@ def dequantize_4bit_impl(
torch.Tensor:
Dequantized tensor.
"""
if A.shape[0] == 1:
transpose = False
A = A.squeeze(0)
elif A.shape[1] == 1:
transpose = True
A = A.squeeze(1)
transpose = True if A.shape[0] == 1 else False
A = A.reshape(-1)

if quant_state is None:
assert absmax is not None and out is not None
Expand Down

0 comments on commit b7ca20b

Please sign in to comment.