Bug fix for double_quant
Xia-Weiwen committed Apr 18, 2024
1 parent 8d0b695 commit 67d8661
Showing 4 changed files with 13 additions and 23 deletions.
bitsandbytes/backends/cpu.py (1 addition, 1 deletion)
@@ -31,7 +31,7 @@ def double_quant(
         cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
     ):
         assert_on_cpu([A, col_stats, row_stats, out_col, out_row])
-        return double_quant_impl(A, col_stats, row_stats, out_col, out_row)
+        return double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)
 
     @classmethod
     def transform(cls, A, to_order=None, from_order='row', out=None, transpose=False, state=None, ld=None):
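The one-line fix above matters because the CPU backend was dropping the caller's threshold, so double_quant_impl always ran with its default of 0.0 and never extracted outliers. A minimal sketch of the now-working call path, assuming a bitsandbytes build with this multi-backend CPU support (the tensor values and the 6.0 threshold are illustrative, not from the commit):

    import torch
    from bitsandbytes import functional as F

    A = torch.randn(32, 64)
    A[0, 0] = 100.0  # plant an obvious outlier

    # With a nonzero threshold, entries with |x| >= threshold are returned
    # in a sparse COO tensor instead of being folded into the int8 data.
    CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=6.0)
    assert coo_tensor is not None  # before this fix, the CPU path left it None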
bitsandbytes/backends/cpu_xpu_common.py (9 additions, 2 deletions)
@@ -1,4 +1,5 @@
 import torch
+import warnings
 
 
 Tensor = torch.Tensor
@@ -66,7 +67,7 @@ def quant_to_int8(A, stats):
         if row_stats is None or col_stats is None:
             row_stats, col_stats = get_row_col_stats(A)
     else:
-        outlier_indices = torch.abs(A) > threshold  # find outliers
+        outlier_indices = torch.abs(A) >= threshold  # find outliers
         outlier_coord = outlier_indices.nonzero()  # get outlier coordinates
         outlier_rows = outlier_coord[:, 0]  # outlier row for COO sparse tensor
         outlier_cols = outlier_coord[:, 1]  # outlier column for COO sparse tensor
@@ -77,10 +78,13 @@ def quant_to_int8(A, stats):
         if row_stats is None or col_stats is None:
             A[outlier_indices] = 0  # zero out outliers
             row_stats, col_stats = get_row_col_stats(A)
-            A[outlier_indices] = outlier_values  # restore outliers for later use
 
     quant_by_row = quant_to_int8(A, row_stats.unsqueeze(-1))
     quant_by_col = quant_to_int8(A, col_stats.unsqueeze(0))
+
+    if coo_tensor is not None:
+        A[outlier_indices] = outlier_values  # restore outliers for later use
+
     if out_row is not None:
         out_row.copy_(quant_by_row)
     else:
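For context, this reorder fixes a subtle bug: the outliers were previously written back into A before quant_by_row and quant_by_col were computed, so the very values the threshold was supposed to exclude leaked into the int8 data. A standalone sketch of the corrected zero-quantize-restore pattern (a simplified helper, not the library's implementation):

    import torch

    def quantize_excluding_outliers(A: torch.Tensor, threshold: float):
        outlier_mask = torch.abs(A) >= threshold
        outlier_values = A[outlier_mask].clone()

        A[outlier_mask] = 0  # zero outliers so they don't distort the scales
        row_stats = A.abs().amax(dim=1, keepdim=True).clamp_min(1e-8)
        col_stats = A.abs().amax(dim=0, keepdim=True).clamp_min(1e-8)
        q_row = torch.round(A * 127.0 / row_stats).to(torch.int8)
        q_col = torch.round(A * 127.0 / col_stats).to(torch.int8)

        A[outlier_mask] = outlier_values  # restore only after both quantizations
        return q_row, q_col, row_stats, col_stats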
@@ -189,6 +193,9 @@ def mm_dequant_impl(
     if len(out_shape) == 3:
         out_shape = (out_shape[0] * out_shape[1], out_shape[2])
 
+    if compute_dtype not in [torch.float32, torch.bfloat16]:
+        warnings.warn(f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use float instead")
+        compute_dtype = torch.float32
     A_reshaped = A.reshape(out_shape).to(compute_dtype)
     row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype)
     col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype)
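The mm_dequant_impl guard added above degrades gracefully instead of erroring on an unsupported compute dtype such as float16. The same pattern in isolation, as a hedged sketch (the function name is made up for illustration):

    import warnings
    import torch

    def resolve_compute_dtype(requested: torch.dtype, device: torch.device) -> torch.dtype:
        # Only float32 and bfloat16 are supported by this kernel path;
        # anything else falls back to float32 with a warning.
        if requested not in (torch.float32, torch.bfloat16):
            warnings.warn(
                f"mm_dequant_{device}: compute_dtype {requested} is not supported, "
                "will use float instead"
            )
            return torch.float32
        return requested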
bitsandbytes/nn/modules.py (0 additions, 19 deletions)
@@ -594,19 +594,6 @@ def cpu(self):
         setattr(self, "SCB", SCB)
         return self
 
-    def xpu(self):
-        # we store the 8-bit rows-major weight
-        B = self.data.contiguous().half().cpu()
-        CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
-        if CBt is not None:
-            del CBt
-        if SCBt is not None:
-            del SCBt
-        self.data = CB
-        setattr(self, "CB", CB)
-        setattr(self, "SCB", SCB)
-        return self
-
     @overload
     def to(
         self: T,
@@ -626,12 +613,6 @@ def to(self, *args, **kwargs):
 
         if device is not None and device.type == "cuda" and self.data.device.type == "cpu":
             return self.cuda(device)
-        elif (
-            device is not None
-            and device.type == "xpu"
-            and self.data.dtype != torch.int8
-        ):
-            return self.xpu()
         elif (
             device is not None
             and device.type == "cpu"
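With def xpu() and its dispatch branch removed, an Int8Params moved to an XPU device no longer gets a dedicated quantize-on-transfer path; only the CUDA and CPU branches remain. A rough sketch of the dispatch shape after this commit (simplified, and the collapsed CPU condition is assumed to mirror the removed XPU one):

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)

        if device is not None and device.type == "cuda" and self.data.device.type == "cpu":
            return self.cuda(device)  # quantize while moving to CUDA
        elif device is not None and device.type == "cpu" and self.data.dtype != torch.int8:
            return self.cpu()  # quantize while moving to CPU
        # otherwise fall through to the regular Parameter-style .to(...) handling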
tests/test_functional.py (3 additions, 1 deletion)
@@ -1221,6 +1221,8 @@ def test_coo_double_quant(dim1, dim2, device, dtype):
     CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
     CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
 
+    if idx.sum() > 0:
+        assert coo_tensor is not None
     if coo_tensor is not None:
         A1 = A * idx
         A2 = torch.zeros_like(A)
@@ -1229,7 +1231,7 @@ def test_coo_double_quant(dim1, dim2, device, dtype):
 
     A1 = A * (idx == 0)
     A2 = (CA.float() * statsA.unsqueeze(1) / 127).to(dtype)
-    torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
+    torch.testing.assert_close(A1, A2, rtol=0.05, atol=1.5e-2)
 
 
 @pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
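The last change makes the test compare against the A1 it just built instead of recomputing the mask product, and the new assertion guarantees a COO tensor actually appears whenever planted outliers exist. The dequantization identity the test checks, as a self-contained sketch (names mirror the test; the mask threshold is illustrative):

    import torch

    A = torch.randn(16, 32)
    idx = torch.abs(A) >= 2.0  # stand-in outlier mask

    # Row-wise absmax int8 quantization of the non-outlier part:
    statsA = (A * (idx == 0)).abs().amax(dim=1).clamp_min(1e-8)
    CA = torch.round(A * (idx == 0) * 127.0 / statsA.unsqueeze(1)).to(torch.int8)

    # Dequantize and compare on the non-outlier entries:
    A1 = A * (idx == 0)
    A2 = (CA.float() * statsA.unsqueeze(1) / 127).to(A.dtype)
    torch.testing.assert_close(A1, A2, rtol=0.05, atol=1.5e-2)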
