Skip to content

Commit b7ca20b

Browse files
committed
fix 4bit tensor shape
Signed-off-by: jiqing-feng <[email protected]>
1 parent 629af94 commit b7ca20b

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def quantize_4bit_impl(
397397
quant_type=quant_type,
398398
)
399399

400-
return out.unsqueeze(0), state
400+
return out.reshape(-1, 1), state
401401

402402

403403
def dequant_8bit(A, offset, quant_state):
@@ -449,12 +449,8 @@ def dequantize_4bit_impl(
449449
torch.Tensor:
450450
Dequantized tensor.
451451
"""
452-
if A.shape[0] == 1:
453-
transpose = False
454-
A = A.squeeze(0)
455-
elif A.shape[1] == 1:
456-
transpose = True
457-
A = A.squeeze(1)
452+
transpose = True if A.shape[0] == 1 else False
453+
A = A.reshape(-1)
458454

459455
if quant_state is None:
460456
assert absmax is not None and out is not None

0 commit comments

Comments
 (0)