diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py
index 82c411166..5183e6485 100644
--- a/bitsandbytes/backends/cpu.py
+++ b/bitsandbytes/backends/cpu.py
@@ -31,7 +31,7 @@ def double_quant(
         cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
     ):
         assert_on_cpu([A, col_stats, row_stats, out_col, out_row])
-        return double_quant_impl(A, col_stats, row_stats, out_col, out_row)
+        return double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)
 
     @classmethod
     def transform(cls, A, to_order=None, from_order='row', out=None, transpose=False, state=None, ld=None):
diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py
index 5be83d8e3..7c7927c88 100644
--- a/bitsandbytes/backends/cpu_xpu_common.py
+++ b/bitsandbytes/backends/cpu_xpu_common.py
@@ -1,4 +1,5 @@
 import torch
+import warnings
 
 Tensor = torch.Tensor
 
@@ -66,7 +67,7 @@ def quant_to_int8(A, stats):
         if row_stats is None or col_stats is None:
             row_stats, col_stats = get_row_col_stats(A)
     else:
-        outlier_indices = torch.abs(A) > threshold  # find outliers
+        outlier_indices = torch.abs(A) >= threshold  # find outliers
         outlier_coord = outlier_indices.nonzero()  # get outlier coordinates
         outlier_rows = outlier_coord[:, 0]  # outlier row for COO sparse tensor
         outlier_cols = outlier_coord[:, 1]  # outlier column for COO sparse tensor
@@ -77,10 +78,13 @@ def quant_to_int8(A, stats):
         if row_stats is None or col_stats is None:
             A[outlier_indices] = 0  # zero out outliers
             row_stats, col_stats = get_row_col_stats(A)
-            A[outlier_indices] = outlier_values  # restore outliers for later use
 
     quant_by_row = quant_to_int8(A, row_stats.unsqueeze(-1))
     quant_by_col = quant_to_int8(A, col_stats.unsqueeze(0))
+
+    if coo_tensor is not None:
+        A[outlier_indices] = outlier_values  # restore outliers for later use
+
     if out_row is not None:
         out_row.copy_(quant_by_row)
     else:
@@ -189,6 +193,9 @@ def mm_dequant_impl(
     if len(out_shape) == 3:
         out_shape = (out_shape[0] * out_shape[1], out_shape[2])
 
+    if compute_dtype not in [torch.float32, torch.bfloat16]:
+        warnings.warn(f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use float instead")
+        compute_dtype = torch.float32
     A_reshaped = A.reshape(out_shape).to(compute_dtype)
     row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype)
     col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype)
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index bcba8b3d2..c0efafec0 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -594,19 +594,6 @@ def cpu(self):
         setattr(self, "SCB", SCB)
         return self
 
-    def xpu(self):
-        # we store the 8-bit rows-major weight
-        B = self.data.contiguous().half().cpu()
-        CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
-        if CBt is not None:
-            del CBt
-        if SCBt is not None:
-            del SCBt
-        self.data = CB
-        setattr(self, "CB", CB)
-        setattr(self, "SCB", SCB)
-        return self
-
     @overload
     def to(
         self: T,
@@ -626,12 +613,6 @@ def to(self, *args, **kwargs):
 
         if device is not None and device.type == "cuda" and self.data.device.type == "cpu":
             return self.cuda(device)
-        elif (
-            device is not None
-            and device.type == "xpu"
-            and self.data.dtype != torch.int8
-        ):
-            return self.xpu()
         elif (
             device is not None
             and device.type == "cpu"
diff --git a/tests/test_functional.py b/tests/test_functional.py
index cf4088c00..ba1e32d77 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1221,6 +1221,8 @@ def test_coo_double_quant(dim1, dim2, device, dtype):
         CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
         CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
 
+        if idx.sum() > 0:
+            assert coo_tensor is not None
         if coo_tensor is not None:
             A1 = A * idx
             A2 = torch.zeros_like(A)
@@ -1229,7 +1231,7 @@ def test_coo_double_quant(dim1, dim2, device, dtype):
 
             A1 = A * (idx == 0)
             A2 = (CA.float() * statsA.unsqueeze(1) / 127).to(dtype)
-            torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
+            torch.testing.assert_close(A1, A2, rtol=0.05, atol=1.5e-2)
 
 
 @pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
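
Usage note (not part of the patch): a minimal sketch of the behavior the fix enables, assuming a bitsandbytes build where F.double_quant dispatches CPU tensors to double_quant_impl. The five-value return signature follows the test above; the input tensor and threshold here are arbitrary example values.

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(64, 64)  # example fp32 activations on CPU
    threshold = 2.0

    # With `threshold` now forwarded to double_quant_impl, entries with
    # |a| >= threshold are collected into a COO sparse tensor rather than
    # quantized, and (per the patch) A has its outliers restored before
    # double_quant returns, so A is left unmodified for the caller.
    CA, CAt, row_stats, col_stats, coo_tensor = F.double_quant(A, threshold=threshold)

    if (A.abs() >= threshold).any():
        assert coo_tensor is not None  # outliers present -> COO tensor returned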