1 parent 629af94 commit b7ca20b
bitsandbytes/backends/cpu_xpu_common.py
@@ -397,7 +397,7 @@ def quantize_4bit_impl(
         quant_type=quant_type,
     )
 
-    return out.unsqueeze(0), state
+    return out.reshape(-1, 1), state
 
 
 def dequant_8bit(A, offset, quant_state):
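For context, the quantize-side change only alters the shape of the packed 4-bit payload: the old code returned it as a row vector of shape (1, n) via `unsqueeze(0)`, while the new code returns a column vector of shape (n, 1) via `reshape(-1, 1)`. A minimal sketch of just that shape difference follows; the tensor is a stand-in for the real packed output, and the quantization state returned alongside it is omitted, so this is not an actual call to `quantize_4bit_impl`:

```python
import torch

# Stand-in for the packed 4-bit payload produced by quantize_4bit_impl;
# the quantization state returned alongside it is omitted in this sketch.
packed = torch.zeros(8, dtype=torch.uint8)

old_out = packed.unsqueeze(0)    # old return shape: (1, 8) row vector
new_out = packed.reshape(-1, 1)  # new return shape: (8, 1) column vector

print(old_out.shape, new_out.shape)  # torch.Size([1, 8]) torch.Size([8, 1])
```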
@@ -449,12 +449,8 @@ def dequantize_4bit_impl(
     torch.Tensor:
         Dequantized tensor.
     """
-    if A.shape[0] == 1:
-        transpose = False
-        A = A.squeeze(0)
-    elif A.shape[1] == 1:
-        transpose = True
-        A = A.squeeze(1)
+    transpose = True if A.shape[0] == 1 else False
+    A = A.reshape(-1)
 
     if quant_state is None:
         assert absmax is not None and out is not None
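On the dequantize side, the squeeze-based branching is replaced by a flat `reshape(-1)` plus a transpose flag inferred from the first dimension: a (1, n) input is now treated as the transposed case, while the (n, 1) column layout produced by the updated `quantize_4bit_impl` maps to `transpose = False`. The sketch below isolates only that shape handling; the helper name is made up, and the real `dequantize_4bit_impl` also consumes `quant_state`, `absmax`, and `out`:

```python
import torch

# Illustrative helper (not part of bitsandbytes): mirrors only the shape logic
# introduced in this commit for dequantize_4bit_impl.
def normalize_packed(A: torch.Tensor):
    # A leading singleton dimension (1, n) marks the transposed case; the
    # (n, 1) column layout from quantize_4bit_impl flattens with transpose=False.
    transpose = A.shape[0] == 1
    return A.reshape(-1), transpose

col = torch.zeros(8, 1, dtype=torch.uint8)  # new quantize output layout
row = torch.zeros(1, 8, dtype=torch.uint8)  # legacy / transposed row layout

print(normalize_packed(col)[1], normalize_packed(row)[1])  # False True
```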