From 13ad630ccce253ea805b70dd712000787f5b9f4f Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 11 Apr 2024 05:20:56 -0700 Subject: [PATCH 01/12] Add int8 ops for Intel CPU & XPU --- bitsandbytes/__init__.py | 12 +- bitsandbytes/autograd/_functions.py | 13 +- bitsandbytes/backends/cpu.py | 287 +++++++++++++++++++++ bitsandbytes/backends/xpu.py | 118 +++++++++ bitsandbytes/functional.py | 5 +- bitsandbytes/nn/modules.py | 38 +++ examples/int8_inference_huggingface_cpu.py | 32 +++ 7 files changed, 499 insertions(+), 6 deletions(-) create mode 100644 bitsandbytes/backends/cpu.py create mode 100644 bitsandbytes/backends/xpu.py create mode 100644 examples/int8_inference_huggingface_cpu.py diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 019a4f6ab..0dae37e8d 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import torch from . import research, utils from .autograd._functions import ( MatmulLtState, @@ -14,13 +15,22 @@ ) from .cextension import lib from .nn import modules +from .backends import register_backend if lib and lib.compiled_with_cuda: - from .backends import register_backend from .backends.cuda import CUDABackend from .optim import adam register_backend("cuda", CUDABackend()) + +elif torch.xpu.is_available(): + from .backends.xpu import XPUBackend + register_backend("xpu", XPUBackend) + +else: + from .backends.cpu import CPUBackend + register_backend("cpu", CPUBackend) + __pdoc__ = { "libbitsandbytes": False, "optim.optimizer.Optimizer8bit": False, diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index e9821cd36..67b8b6b87 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -217,6 +217,8 @@ def backward(ctx, grad_output): def supports_igemmlt(device: torch.device) -> bool: """check if this device supports the optimized int8 kernel""" + if device == torch.device('cpu'): + return True if torch.cuda.get_device_capability(device=device) < (7, 5): return False device_name = torch.cuda.get_device_name(device=device) @@ -312,13 +314,16 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): state.outlier_pool = GlobalOutlierPooler.get_instance() # Cast A to fp16 - if A.dtype != torch.float16: - warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") + A_dtype = torch.float16 + if A.device == torch.device('cpu'): + A_dtype = torch.bfloat16 + if A.dtype != A_dtype: + warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {A_dtype} during quantization") # 1. 
Quantize A if len(A.shape) == 3: A = A.reshape(-1, A.shape[-1]) - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(A_dtype), threshold=state.threshold) if state.threshold > 0.0 and coo_tensorA is not None: if state.has_fp16_weights: @@ -393,7 +398,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): if using_igemmlt: C32A, SA = F.transform(CA, "col32") out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) - if bias is None or bias.dtype == torch.float16: + if bias is None or bias.dtype == A_dtype: # we apply the fused bias here output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) output = output.to(A.dtype) diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py new file mode 100644 index 000000000..31bb52945 --- /dev/null +++ b/bitsandbytes/backends/cpu.py @@ -0,0 +1,287 @@ +import torch + + +Tensor = torch.Tensor + + +def assert_on_cpu(tensors): + on_cpu = True + for t in tensors: + if t is None: continue # NULL pointers are fine + on_cpu &= (t.device.type == 'cpu') + if not on_cpu: + raise TypeError( + 'All input tensors need to be on CPU, but found some tensors to not be on CPU:\n' \ + f' {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}' + ) + return on_cpu + + +@torch.compile(dynamic=True, options={"fx_graph_cache": True}) +def double_quant_common( + A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 +): + """ + Find absolute max valus of each row/column of a tensor, and symmetrically quantize it to int8. + If threshold > 0.0, only values <= threshold are counted. All outliers are zeroed out in + the original tensor and they are kept in COO format: (rows, cols, valus) + If threashold == 0.0, there are no outliers. + Args: + A The tensor to be analyzed and quantized. + col_stats Absolute max values of each column of A. If it is not None, use the values directly. + Otherwise, find the values. + row_stats Absolute max values of each row of A. If it is not None, use the values directly. + Otherwise, find the values. + out_col Output buffer for the result quantized per column if it is not None + out_row Output buffer for the result quantized per row if it is not None + threshold The threshold for finding outliers if it is > 0.0. Otherwise it has no effect. 
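As an aside, the row/column quantization scheme introduced by this new backend file can be restated as a short standalone sketch: outliers above the threshold are pulled out into COO triplets and the remaining values are scaled by per-row / per-column absolute maxima into int8. Function and tensor names below are illustrative, not part of the patch, and the all-zero-row guard is an extra robustness assumption:

import torch

def symmetric_int8_quant_sketch(A: torch.Tensor, threshold: float = 0.0):
    # Row-wise and column-wise symmetric int8 quantization with optional
    # outlier extraction, mirroring the double_quant logic above.
    A = A.float().clone()
    outliers = None
    if threshold > 0.0:
        mask = A.abs() > threshold
        coords = mask.nonzero()
        outliers = (coords[:, 0], coords[:, 1], A[mask])  # COO triplets (row, col, value)
        A[mask] = 0  # quantize the remaining values only
    eps = torch.finfo(torch.float32).tiny  # guard against all-zero rows/cols (not in the patch)
    row_stats = A.abs().amax(dim=1, keepdim=True).clamp_min(eps)  # per-row absmax
    col_stats = A.abs().amax(dim=0, keepdim=True).clamp_min(eps)  # per-column absmax
    quant_row = torch.clamp(torch.round(A / row_stats * 127), -128, 127).to(torch.int8)
    quant_col = torch.clamp(torch.round(A / col_stats * 127), -128, 127).to(torch.int8)
    return quant_row, quant_col, row_stats, col_stats, outliers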
+ Return: + A tuple of output quantized per row, output quantized per column, absolute max values of + each row of A, absolute max values of each column of A, outliers in COO format + """ + from ..functional import COOSparseTensor + cols = A.shape[-1] + if len(A.shape) == 3: + rows = A.shape[0] * A.shape[1] + else: + assert A.dim() == 2, f"double_quant: Input tensor should be 2d or 3d but got {A.dim()}d" + rows = A.shape[0] + A = A.reshape(rows, cols) + + coo_tensor = None + + def get_row_col_stats(A): + row_stats = torch.max(torch.abs(A), 1).values # absolute max of each row + col_stats = torch.max(torch.abs(A), 0).values # absolute max of each col + return row_stats, col_stats + + def quant_to_int8(A, stats): + return torch.clamp(torch.round(A / stats * 127).to(torch.int8), -128, 127) + + if threshold == 0.0: + if row_stats is None or col_stats is None: + row_stats, col_stats = get_row_col_stats(A) + else: + outlier_indices = torch.abs(A) > threshold # find outliers + outlier_coord = outlier_indices.nonzero() # get outlier coordinates + outlier_rows = outlier_coord[:, 0] # outlier row for COO sparse tensor + outlier_cols = outlier_coord[:, 1] # outlier column for COO sparse tensor + outlier_values = A[outlier_indices] # outlier values for COO sparse tensor + coo_tensor = COOSparseTensor( + A.shape[0], A.shape[1], outlier_values.numel(), outlier_rows.int(), outlier_cols.int(), outlier_values + ) + if row_stats is None or col_stats is None: + A[outlier_indices] = 0 # zero out outliers + row_stats, col_stats = get_row_col_stats(A) + A[outlier_indices] = outlier_values # restore outliers for later use + + quant_by_row = quant_to_int8(A, row_stats.unsqueeze(-1)) + quant_by_col = quant_to_int8(A, col_stats.unsqueeze(0)) + if out_row is not None: + out_row.copy_(quant_by_row) + else: + out_row = quant_by_row + if out_col is not None: + out_col.copy_(quant_by_col) + else: + out_col = quant_by_col + return out_row, out_col, row_stats, col_stats, coo_tensor + + +def igemmlt_common( + A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32 +): + """ + Do GEMMM computation. Data type: int8 * int8 -> int32. 
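The int8 GEMM primitive used on CPU here is torch._int_mm, a private PyTorch op that multiplies two int8 matrices with int32 accumulation. A minimal self-contained illustration with a float fallback follows; shapes are illustrative, and the fallback mirrors the spirit of the patch rather than its exact code:

import torch

def int8_matmul_sketch(A_int8: torch.Tensor, B_int8: torch.Tensor) -> torch.Tensor:
    # C[int32] = A[int8] @ B[int8].T, matching the (m, k) x (n, k) linear layout
    assert A_int8.dtype == torch.int8 and B_int8.dtype == torch.int8
    try:
        return torch._int_mm(A_int8, B_int8.t())  # int8 x int8 -> int32
    except (AttributeError, RuntimeError):
        # portable fallback: compute in float32 and cast back (exact for moderate k)
        return torch.matmul(A_int8.float(), B_int8.t().float()).to(torch.int32)

A = torch.randint(-128, 127, (32, 64), dtype=torch.int8)
B = torch.randint(-128, 127, (48, 64), dtype=torch.int8)
C = int8_matmul_sketch(A, B)
assert C.shape == (32, 48) and C.dtype == torch.int32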
+ Args: + A Activation of linear, data type is int8 + B Weight of linear, data type is int8 + SA Not used for CPU/XPU + SB Not used for CPU/XPU + out Specified output tensor if it is not None + Sout Not used for CPU/XPU but returned as is + dtype Data type of output + Return: + A tuple of GEMM result in dtype and Sout + """ + assert A.dtype == torch.int8 + assert B.dtype == torch.int8 + if out is not None: + assert out.dtype == dtype + + dimsA = A.ndim + dimsB = B.ndim + shapeA = A.shape + shapeB = B.shape + assert dimsA in [2, 3], 'Only two or three dimensional matrices are supported for argument A' + assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' + + if dimsA == 2: + m = shapeA[0] + elif dimsA == 3: + m = shapeA[0] * shapeA[1] + n = shapeB[0] + k = shapeA[-1] + assert shapeA[-1] == shapeB[-1], f'Shapes of A and B do not match, got {shapeA} and {shapeB}' + shapeOut = (shapeA[0], shapeA[1], n) if dimsA == 3 else (m, n) + + # if the tensor is empty, return a transformed empty tensor with the right dimensions + if shapeA[0] == 0 and dimsA == 2: + return torch.empty((0, n), device=A.device, dtype=A.dtype) + elif shapeA[1] == 0 and dimsA == 3: + return torch.empty(tuple(shapeA[:2] + [n]), device=A.device, dtype=A.dtype) + + A_reshaped = A.reshape(m, k) + + if assert_on_cpu([A_reshaped, B]): + C = torch._int_mm(A_reshaped, B.T).to(dtype) + else: + C = torch.nn.functional.linear(A_reshaped, B).to(dtype) + if C.ndim != dimsA: + C = C.reshape(shapeOut) + if out is not None: + out.copy_(C) + else: + out = C + + return out, Sout + + +@torch.compile(dynamic=True, options={"fx_graph_cache": True}) +def mm_dequant_common( + A, + quant_state, + row_stats, + col_stats, + out=None, + new_row_stats=None, + new_col_stats=None, + bias=None, + compute_dtype=torch.float32, + output_dtype=torch.float32 +): + """ + Dequant and add bias + out = A_int32 * (scale_A, scale_B) / 127 * 127 + bias + Args: + A The output of int8 gemm, whose dtype is int32 + quant_state Not used for CPU + row_stats Absolute max value of each row of input (A) of gemm + col_stats Absolute max value of each row of weight (B) of gemm + out Output buffer + new_row_stats Not used for CPU/XPU + new_col_stats Not used for CPU/XPU + bias Bias of linear + compute_dtype Data type for computation + output_dtype Data type for output + Return: + The result + """ + assert A.dtype == torch.int32 + out_shape = A.shape + if len(out_shape) == 3: + out_shape = (out_shape[0] * out_shape[1], out_shape[2]) + + A_reshaped = A.reshape(out_shape).to(compute_dtype) + row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype) + col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype) + out = A_reshaped * row_stats * col_stats / (127 * 127) + if bias is not None: + out = out + bias.to(compute_dtype) + out = out.to(output_dtype) + return out + + +class CPUBackend: + mm_dequant_compute_dtype = torch.bfloat16 + mm_dequant_output_dtype = torch.bfloat16 + + @classmethod + def double_quant( + cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 + ): + assert_on_cpu([A, col_stats, row_stats, out_col, out_row]) + return double_quant_common(A, col_stats, row_stats, out_col, out_row) + + @classmethod + def transform(cls, A, to_order=None, from_order='row', out=None, transpose=False, state=None, ld=None): + """ + Transform tensor A to to_order. It is originally designed for CUDA. + For CPU, it returns the original tensor if transpose=False. 
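The matching dequantization step is just the two absmax scalings undone plus an optional fused bias: out = C_int32 * row_absmax * col_absmax / 127^2 + bias, where row_absmax comes from the activation's per-row statistics and col_absmax from the weight's per-row statistics (which index the output columns). A tiny standalone restatement with illustrative names:

import torch

def dequant_sketch(C_int32, row_absmax, col_absmax, bias=None, out_dtype=torch.bfloat16):
    # undo the two symmetric 127-scalings applied to the activation and the weight
    out = (C_int32.float()
           * row_absmax.reshape(-1, 1).float()
           * col_absmax.reshape(1, -1).float()
           / (127 * 127))
    if bias is not None:
        out = out + bias.float()
    return out.to(out_dtype)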
+ Otherwise, it returns the transpose of A + """ + assert_on_cpu([A, out]) + if transpose: + if out is not None: + out.copy_(A.T) + else: + out = A.T + else: + if out is not None: + out.copy_(A) + else: + out = A + return out, state + + @classmethod + def igemmlt(cls, A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32): + assert_on_cpu([A, B]) + return igemmlt_common(A, B, SA, SB, out, Sout, dtype) + + @classmethod + def mm_dequant( + cls, + A, + quant_state, + row_stats, + col_stats, + out=None, + new_row_stats=None, + new_col_stats=None, + bias=None + ): + assert_on_cpu([A, row_stats, col_stats, out, bias]) + return mm_dequant_common( + A, + quant_state, + row_stats, + col_stats, + out, + new_row_stats, + new_col_stats, + bias, + cls.mm_dequant_compute_dtype, + cls.mm_dequant_output_dtype + ) + + @classmethod + def extract_outliers(cls, A, SA, idx): + """ + Extract columns of A by idx + """ + assert_on_cpu([A]) + return A[:, idx].contiguous() + + @classmethod + def quantize_4bit( + cls, + A: Tensor, + absmax: Tensor = None, + out: Tensor = None, + blocksize=64, + compress_statistics=False, + quant_type="fp4", + ) -> Tensor: + assert False, "quantize_4bit not yet implemented for CPU backend" + + @classmethod + def dequantize_4bit( + cls, + A: Tensor, + quant_state = None, + absmax: Tensor = None, + out: Tensor = None, + blocksize: int = 64, + quant_type="fp4", + ) -> Tensor: + assert False, "dequantize_4bit not yet implemented for CPU backend" diff --git a/bitsandbytes/backends/xpu.py b/bitsandbytes/backends/xpu.py new file mode 100644 index 000000000..9ee8a09dc --- /dev/null +++ b/bitsandbytes/backends/xpu.py @@ -0,0 +1,118 @@ +# For Intel GPU (xpu is the device name for Intel GPU in PyTorch) +import torch +from .cpu import ( + double_quant_common, + igemmlt_common, + mm_dequant_common, +) + +Tensor = torch.Tensor + +def assert_on_xpu(tensors): + on_xpu = True + for t in tensors: + if t is None: continue # NULL pointers are fine + on_xpu &= (t.device.type == 'xpu') + if not on_xpu: + raise TypeError( + 'All input tensors need to be on XPU, but found some tensors to not be on XPU:\n' \ + f' {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}' + ) + return on_xpu + + +class XPUBackend: + mm_dequant_compute_dtype = torch.half + mm_dequant_output_dtype = torch.half + + @classmethod + @torch.compile(dynamic=True, options={"fx_graph_cache": True}) + def double_quant( + cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 + ): + assert_on_xpu([A, col_stats, row_stats, out_col, out_row]) + return double_quant_common(A, col_stats, row_stats, out_col, out_row) + + @classmethod + def transform(cls, A, to_order=None, from_order='row', out=None, transpose=False, state=None, ld=None): + """ + Transform tensor A to to_order. It is originally designed for CUDA. + For XPU, it returns the original tensor if transpose=False. 
+ Otherwise, it returns the transpose of A + """ + assert_on_xpu([A, out]) + if transpose: + if out is not None: + out.copy_(A.T) + else: + out = A.T + else: + if out is not None: + out.copy_(A) + else: + out = A + return out, state + + @classmethod + def igemmlt(cls, A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32): + assert_on_xpu([A, B]) + return igemmlt_common(A, B, SA, SB, out, Sout, dtype) + + @classmethod + @torch.compile(dynamic=True, options={"fx_graph_cache": True}) + def mm_dequant( + cls, + A, + quant_state, + row_stats, + col_stats, + out=None, + new_row_stats=None, + new_col_stats=None, + bias=None + ): + assert_on_xpu([A, row_stats, col_stats, out, bias]) + return mm_dequant_common( + A, + quant_state, + row_stats, + col_stats, + out, + new_row_stats, + new_col_stats, + bias, + cls.mm_dequant_compute_dtype, + cls.mm_dequant_output_dtype + ) + + @classmethod + def extract_outliers(cls, A, SA, idx): + """ + Extract columns of A by idx + """ + assert_on_xpu([A]) + return A[:, idx].contiguous() + + @classmethod + def quantize_4bit( + cls, + A: Tensor, + absmax: Tensor = None, + out: Tensor = None, + blocksize=64, + compress_statistics=False, + quant_type="fp4", + ) -> Tensor: + assert False, "quantize_4bit not yet implemented for XPU backend" + + @classmethod + def dequantize_4bit( + cls, + A: Tensor, + quant_state = None, + absmax: Tensor = None, + out: Tensor = None, + blocksize: int = 64, + quant_type="fp4", + ) -> Tensor: + assert False, "dequantize_4bit not yet implemented for XPU backend" diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 6bb02944d..baba76963 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1945,7 +1945,10 @@ class COOSparseTensor: def __init__(self, rows, cols, nnz, rowidx, colidx, values): assert rowidx.dtype == torch.int32 assert colidx.dtype == torch.int32 - assert values.dtype == torch.float16 + if values.device == torch.device('cpu'): + assert values.dtype in [torch.bfloat16, torch.float] + else: + assert values.dtype == torch.float16 assert values.numel() == nnz assert rowidx.numel() == nnz assert colidx.numel() == nnz diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ec14e5940..bcba8b3d2 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -581,6 +581,32 @@ def cuda(self, device): return self + def cpu(self): + # we store the 8-bit rows-major weight + B = self.data.contiguous().bfloat16().cpu() + CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) + if CBt is not None: + del CBt + if SCBt is not None: + del SCBt + self.data = CB + setattr(self, "CB", CB) + setattr(self, "SCB", SCB) + return self + + def xpu(self): + # we store the 8-bit rows-major weight + B = self.data.contiguous().half().cpu() + CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) + if CBt is not None: + del CBt + if SCBt is not None: + del SCBt + self.data = CB + setattr(self, "CB", CB) + setattr(self, "SCB", SCB) + return self + @overload def to( self: T, @@ -600,6 +626,18 @@ def to(self, *args, **kwargs): if device is not None and device.type == "cuda" and self.data.device.type == "cpu": return self.cuda(device) + elif ( + device is not None + and device.type == "xpu" + and self.data.dtype != torch.int8 + ): + return self.xpu() + elif ( + device is not None + and device.type == "cpu" + and self.data.dtype != torch.int8 + ): + return self.cpu() else: new_param = Int8Params( super().to(device=device, dtype=dtype, non_blocking=non_blocking), 
diff --git a/examples/int8_inference_huggingface_cpu.py b/examples/int8_inference_huggingface_cpu.py new file mode 100644 index 000000000..b41605893 --- /dev/null +++ b/examples/int8_inference_huggingface_cpu.py @@ -0,0 +1,32 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig +import time + +MAX_NEW_TOKENS = 64 +model_id = "facebook/opt-1.3b" + +text = 'Hamburg is in which country?\n' +tokenizer = AutoTokenizer.from_pretrained(model_id) +input_ids = tokenizer(text, return_tensors="pt").input_ids + +print('Loading model {}...'.format(model_id)) +quantization_config = BitsAndBytesConfig(load_in_8bit=True) +model = AutoModelForCausalLM.from_pretrained( + model_id, + device_map='auto', + quantization_config=quantization_config, + torch_dtype=torch.bfloat16 +) +print('model dtype = {}'.format(model.dtype)) + +with torch.no_grad(): + t0 = time.time() + generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS) + latency = time.time() - t0 + result = "| latency: " + str(round(latency * 1000, 3)) + " ms |" + print('+' + '-' * (len(result) - 2) + '+') + print(result) + print('+' + '-' * (len(result) - 2) + '+') + +output = tokenizer.decode(generated_ids[0], skip_special_tokens=True) +print(f"output: {output}") From 77be40bda0ee724b1d734f107534f04846e64e8a Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 15 Apr 2024 15:48:20 +0800 Subject: [PATCH 02/12] Remove XPU code; remove cpu example; add UT --- bitsandbytes/__init__.py | 11 +- bitsandbytes/backends/cpu.py | 186 +------------------ bitsandbytes/backends/cpu_xpu_common.py | 203 +++++++++++++++++++++ bitsandbytes/backends/xpu.py | 118 ------------ bitsandbytes/functional.py | 2 +- examples/int8_inference_huggingface_cpu.py | 32 ---- tests/test_functional.py | 103 +++++++++-- 7 files changed, 302 insertions(+), 353 deletions(-) create mode 100644 bitsandbytes/backends/cpu_xpu_common.py delete mode 100644 bitsandbytes/backends/xpu.py delete mode 100644 examples/int8_inference_huggingface_cpu.py diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 0dae37e8d..48144a870 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -17,20 +17,15 @@ from .nn import modules from .backends import register_backend +from .backends.cpu import CPUBackend +register_backend("cpu", CPUBackend) + if lib and lib.compiled_with_cuda: from .backends.cuda import CUDABackend from .optim import adam register_backend("cuda", CUDABackend()) -elif torch.xpu.is_available(): - from .backends.xpu import XPUBackend - register_backend("xpu", XPUBackend) - -else: - from .backends.cpu import CPUBackend - register_backend("cpu", CPUBackend) - __pdoc__ = { "libbitsandbytes": False, "optim.optimizer.Optimizer8bit": False, diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py index 31bb52945..82c411166 100644 --- a/bitsandbytes/backends/cpu.py +++ b/bitsandbytes/backends/cpu.py @@ -1,4 +1,9 @@ import torch +from .cpu_xpu_common import ( + double_quant_impl, + igemmlt_impl, + mm_dequant_impl, +) Tensor = torch.Tensor @@ -17,181 +22,6 @@ def assert_on_cpu(tensors): return on_cpu -@torch.compile(dynamic=True, options={"fx_graph_cache": True}) -def double_quant_common( - A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 -): - """ - Find absolute max valus of each row/column of a tensor, and symmetrically quantize it to int8. - If threshold > 0.0, only values <= threshold are counted. 
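The state that the Int8Params.cpu() hook added in nn/modules.py keeps around is the row-major int8 weight (CB) plus per-row absmax scales (SCB). A hedged standalone restatement of what gets stored, not the exact code path:

import torch

W = torch.randn(256, 128, dtype=torch.bfloat16)                  # original weight
SCB = W.float().abs().amax(dim=1)                                # per-row absmax scale
CB = torch.clamp(torch.round(W.float() / SCB.unsqueeze(1) * 127), -128, 127).to(torch.int8)
W_restored = CB.float() * SCB.unsqueeze(1) / 127                 # approximate dequantization
print((W.float() - W_restored).abs().max())                      # per-row error is at most SCB / 254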
All outliers are zeroed out in - the original tensor and they are kept in COO format: (rows, cols, valus) - If threashold == 0.0, there are no outliers. - Args: - A The tensor to be analyzed and quantized. - col_stats Absolute max values of each column of A. If it is not None, use the values directly. - Otherwise, find the values. - row_stats Absolute max values of each row of A. If it is not None, use the values directly. - Otherwise, find the values. - out_col Output buffer for the result quantized per column if it is not None - out_row Output buffer for the result quantized per row if it is not None - threshold The threshold for finding outliers if it is > 0.0. Otherwise it has no effect. - Return: - A tuple of output quantized per row, output quantized per column, absolute max values of - each row of A, absolute max values of each column of A, outliers in COO format - """ - from ..functional import COOSparseTensor - cols = A.shape[-1] - if len(A.shape) == 3: - rows = A.shape[0] * A.shape[1] - else: - assert A.dim() == 2, f"double_quant: Input tensor should be 2d or 3d but got {A.dim()}d" - rows = A.shape[0] - A = A.reshape(rows, cols) - - coo_tensor = None - - def get_row_col_stats(A): - row_stats = torch.max(torch.abs(A), 1).values # absolute max of each row - col_stats = torch.max(torch.abs(A), 0).values # absolute max of each col - return row_stats, col_stats - - def quant_to_int8(A, stats): - return torch.clamp(torch.round(A / stats * 127).to(torch.int8), -128, 127) - - if threshold == 0.0: - if row_stats is None or col_stats is None: - row_stats, col_stats = get_row_col_stats(A) - else: - outlier_indices = torch.abs(A) > threshold # find outliers - outlier_coord = outlier_indices.nonzero() # get outlier coordinates - outlier_rows = outlier_coord[:, 0] # outlier row for COO sparse tensor - outlier_cols = outlier_coord[:, 1] # outlier column for COO sparse tensor - outlier_values = A[outlier_indices] # outlier values for COO sparse tensor - coo_tensor = COOSparseTensor( - A.shape[0], A.shape[1], outlier_values.numel(), outlier_rows.int(), outlier_cols.int(), outlier_values - ) - if row_stats is None or col_stats is None: - A[outlier_indices] = 0 # zero out outliers - row_stats, col_stats = get_row_col_stats(A) - A[outlier_indices] = outlier_values # restore outliers for later use - - quant_by_row = quant_to_int8(A, row_stats.unsqueeze(-1)) - quant_by_col = quant_to_int8(A, col_stats.unsqueeze(0)) - if out_row is not None: - out_row.copy_(quant_by_row) - else: - out_row = quant_by_row - if out_col is not None: - out_col.copy_(quant_by_col) - else: - out_col = quant_by_col - return out_row, out_col, row_stats, col_stats, coo_tensor - - -def igemmlt_common( - A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32 -): - """ - Do GEMMM computation. Data type: int8 * int8 -> int32. 
- Args: - A Activation of linear, data type is int8 - B Weight of linear, data type is int8 - SA Not used for CPU/XPU - SB Not used for CPU/XPU - out Specified output tensor if it is not None - Sout Not used for CPU/XPU but returned as is - dtype Data type of output - Return: - A tuple of GEMM result in dtype and Sout - """ - assert A.dtype == torch.int8 - assert B.dtype == torch.int8 - if out is not None: - assert out.dtype == dtype - - dimsA = A.ndim - dimsB = B.ndim - shapeA = A.shape - shapeB = B.shape - assert dimsA in [2, 3], 'Only two or three dimensional matrices are supported for argument A' - assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' - - if dimsA == 2: - m = shapeA[0] - elif dimsA == 3: - m = shapeA[0] * shapeA[1] - n = shapeB[0] - k = shapeA[-1] - assert shapeA[-1] == shapeB[-1], f'Shapes of A and B do not match, got {shapeA} and {shapeB}' - shapeOut = (shapeA[0], shapeA[1], n) if dimsA == 3 else (m, n) - - # if the tensor is empty, return a transformed empty tensor with the right dimensions - if shapeA[0] == 0 and dimsA == 2: - return torch.empty((0, n), device=A.device, dtype=A.dtype) - elif shapeA[1] == 0 and dimsA == 3: - return torch.empty(tuple(shapeA[:2] + [n]), device=A.device, dtype=A.dtype) - - A_reshaped = A.reshape(m, k) - - if assert_on_cpu([A_reshaped, B]): - C = torch._int_mm(A_reshaped, B.T).to(dtype) - else: - C = torch.nn.functional.linear(A_reshaped, B).to(dtype) - if C.ndim != dimsA: - C = C.reshape(shapeOut) - if out is not None: - out.copy_(C) - else: - out = C - - return out, Sout - - -@torch.compile(dynamic=True, options={"fx_graph_cache": True}) -def mm_dequant_common( - A, - quant_state, - row_stats, - col_stats, - out=None, - new_row_stats=None, - new_col_stats=None, - bias=None, - compute_dtype=torch.float32, - output_dtype=torch.float32 -): - """ - Dequant and add bias - out = A_int32 * (scale_A, scale_B) / 127 * 127 + bias - Args: - A The output of int8 gemm, whose dtype is int32 - quant_state Not used for CPU - row_stats Absolute max value of each row of input (A) of gemm - col_stats Absolute max value of each row of weight (B) of gemm - out Output buffer - new_row_stats Not used for CPU/XPU - new_col_stats Not used for CPU/XPU - bias Bias of linear - compute_dtype Data type for computation - output_dtype Data type for output - Return: - The result - """ - assert A.dtype == torch.int32 - out_shape = A.shape - if len(out_shape) == 3: - out_shape = (out_shape[0] * out_shape[1], out_shape[2]) - - A_reshaped = A.reshape(out_shape).to(compute_dtype) - row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype) - col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype) - out = A_reshaped * row_stats * col_stats / (127 * 127) - if bias is not None: - out = out + bias.to(compute_dtype) - out = out.to(output_dtype) - return out - - class CPUBackend: mm_dequant_compute_dtype = torch.bfloat16 mm_dequant_output_dtype = torch.bfloat16 @@ -201,7 +31,7 @@ def double_quant( cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 ): assert_on_cpu([A, col_stats, row_stats, out_col, out_row]) - return double_quant_common(A, col_stats, row_stats, out_col, out_row) + return double_quant_impl(A, col_stats, row_stats, out_col, out_row) @classmethod def transform(cls, A, to_order=None, from_order='row', out=None, transpose=False, state=None, ld=None): @@ -226,7 +56,7 @@ def transform(cls, A, to_order=None, from_order='row', out=None, transpose=False @classmethod def igemmlt(cls, A, B, SA=None, 
SB=None, out=None, Sout=None, dtype=torch.int32): assert_on_cpu([A, B]) - return igemmlt_common(A, B, SA, SB, out, Sout, dtype) + return igemmlt_impl(A, B, SA, SB, out, Sout, dtype) @classmethod def mm_dequant( @@ -241,7 +71,7 @@ def mm_dequant( bias=None ): assert_on_cpu([A, row_stats, col_stats, out, bias]) - return mm_dequant_common( + return mm_dequant_impl( A, quant_state, row_stats, diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py new file mode 100644 index 000000000..e6bc59075 --- /dev/null +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -0,0 +1,203 @@ +import torch + + +Tensor = torch.Tensor + + +def _torch_version_prereq(major, minor): + ver_major = int(torch.__version__.split('.')[0]) + ver_minor = int(torch.__version__.split('.')[1]) + return ver_major * 32 + ver_minor >= major * 32 + minor + + +def _maybe_torch_compile(func): + # torch.compile requires pytorch >= 2.0 + if _torch_version_prereq(2, 0): + options = {} + # fx_graph_cache requires pytorch >= 2.2 + if _torch_version_prereq(2, 2): + options.update({"fx_graph_cache": True}) + return torch.compile(func, dynamic=True, options=options) + return func + + +@_maybe_torch_compile +def double_quant_impl( + A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 +): + """ + Find absolute max valus of each row/column of a tensor, and symmetrically quantize it to int8. + If threshold > 0.0, only values <= threshold are counted. All outliers are zeroed out in + the original tensor and they are kept in COO format: (rows, cols, valus) + If threashold == 0.0, there are no outliers. + Args: + A The tensor to be analyzed and quantized. + col_stats Absolute max values of each column of A. If it is not None, use the values directly. + Otherwise, find the values. + row_stats Absolute max values of each row of A. If it is not None, use the values directly. + Otherwise, find the values. + out_col Output buffer for the result quantized per column if it is not None + out_row Output buffer for the result quantized per row if it is not None + threshold The threshold for finding outliers if it is > 0.0. Otherwise it has no effect. 
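A note on _torch_version_prereq above: encoding a version as major * 32 + minor yields a single integer that orders releases correctly as long as the minor number stays below 32, which avoids the usual string-comparison pitfall. An illustrative check:

def version_key(major, minor):
    # same encoding as _torch_version_prereq
    return major * 32 + minor

assert version_key(2, 10) >= version_key(2, 2)       # "2.10" would sort before "2.2" as strings
assert not version_key(1, 13) >= version_key(2, 0)   # 1.13 is correctly older than 2.0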
+ Return: + A tuple of output quantized per row, output quantized per column, absolute max values of + each row of A, absolute max values of each column of A, outliers in COO format + """ + from ..functional import COOSparseTensor + cols = A.shape[-1] + if len(A.shape) == 3: + rows = A.shape[0] * A.shape[1] + else: + assert A.dim() == 2, f"double_quant: Input tensor should be 2d or 3d but got {A.dim()}d" + rows = A.shape[0] + A = A.reshape(rows, cols) + + coo_tensor = None + + def get_row_col_stats(A): + row_stats = torch.max(torch.abs(A), 1).values # absolute max of each row + col_stats = torch.max(torch.abs(A), 0).values # absolute max of each col + return row_stats, col_stats + + def quant_to_int8(A, stats): + return torch.clamp(torch.round(A * (127.0 / stats)), -128, 127).to(torch.int8) + + if threshold == 0.0: + if row_stats is None or col_stats is None: + row_stats, col_stats = get_row_col_stats(A) + else: + outlier_indices = torch.abs(A) > threshold # find outliers + outlier_coord = outlier_indices.nonzero() # get outlier coordinates + outlier_rows = outlier_coord[:, 0] # outlier row for COO sparse tensor + outlier_cols = outlier_coord[:, 1] # outlier column for COO sparse tensor + outlier_values = A[outlier_indices] # outlier values for COO sparse tensor + coo_tensor = COOSparseTensor( + A.shape[0], A.shape[1], outlier_values.numel(), outlier_rows.int(), outlier_cols.int(), outlier_values + ) + if row_stats is None or col_stats is None: + A[outlier_indices] = 0 # zero out outliers + row_stats, col_stats = get_row_col_stats(A) + A[outlier_indices] = outlier_values # restore outliers for later use + + quant_by_row = quant_to_int8(A, row_stats.unsqueeze(-1)) + quant_by_col = quant_to_int8(A, col_stats.unsqueeze(0)) + if out_row is not None: + out_row.copy_(quant_by_row) + else: + out_row = quant_by_row + if out_col is not None: + out_col.copy_(quant_by_col) + else: + out_col = quant_by_col + # Return float stats to align with CUDA impl + return out_row, out_col, row_stats.float(), col_stats.float(), coo_tensor + + +def igemmlt_impl( + A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32 +): + """ + Do GEMMM computation. Data type: int8 * int8 -> int32. 
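A quick way to see the outlier path of double_quant_impl in action once this module is in place (a sketch; the seed, shapes, and planted value are illustrative, and the helper is called directly only for demonstration):

import torch
from bitsandbytes.backends.cpu_xpu_common import double_quant_impl

torch.manual_seed(0)
A = torch.randn(32, 64, dtype=torch.bfloat16)
A[3, 5] = 8.0                                   # plant one obvious outlier
*_, coo = double_quant_impl(A, threshold=6.0)
# the outlier is recorded as (row, col, value) in COO form for separate handling
assert coo is not None and coo.values.numel() == 1
assert (coo.rowidx[0].item(), coo.colidx[0].item()) == (3, 5)
assert coo.values[0].item() == 8.0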
+ Args: + A Activation of linear, data type is int8 + B Weight of linear, data type is int8 + SA Not used for CPU/XPU + SB Not used for CPU/XPU + out Specified output tensor if it is not None + Sout Not used for CPU/XPU but returned as is + dtype Data type of output + Return: + A tuple of GEMM result in dtype and Sout + """ + assert A.dtype == torch.int8 + assert B.dtype == torch.int8 + if out is not None: + assert out.dtype == dtype + + dimsA = A.ndim + dimsB = B.ndim + shapeA = A.shape + shapeB = B.shape + assert dimsA in [2, 3], 'Only two or three dimensional matrices are supported for argument A' + assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' + + if dimsA == 2: + m = shapeA[0] + elif dimsA == 3: + m = shapeA[0] * shapeA[1] + if shapeA[-1] == shapeB[0]: + B = B.t() + shapeB = B.shape + else: + assert shapeA[-1] == shapeB[-1], f'Shapes of A and B do not match, got {shapeA} and {shapeB}' + n = shapeB[0] + k = shapeA[-1] + + # if the tensor is empty, return a transformed empty tensor with the right dimensions + if shapeA[0] == 0 and dimsA == 2: + return torch.empty((0, n), device=A.device, dtype=A.dtype) + elif shapeA[1] == 0 and dimsA == 3: + return torch.empty(tuple(shapeA[:2] + [n]), device=A.device, dtype=A.dtype) + + A_reshaped = A.reshape(m, k) + + # torch._int_mm is available on CPU since torch 2.4 + if _torch_version_prereq(2, 4): + C = torch._int_mm(A_reshaped, B.T).to(dtype) + else: + C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype) + if C.ndim != dimsA: + assert dimsA == 3 + shapeOut = (shapeA[0], m // shapeA[0], C.shape[-1]) + C = C.reshape(shapeOut) + if out is not None: + out.copy_(C) + else: + out = C + + return out, Sout + + +@_maybe_torch_compile +def mm_dequant_impl( + A, + quant_state, + row_stats, + col_stats, + out=None, + new_row_stats=None, + new_col_stats=None, + bias=None, + compute_dtype=torch.float32, + output_dtype=torch.float32 +): + """ + Dequant and add bias + out = A_int32 * (abs_max_A * abs_max_B) / 127 * 127 + bias + Args: + A The output of int8 gemm, whose dtype is int32 + quant_state Not used for CPU + row_stats Absolute max value of each row of input (A) of gemm + col_stats Absolute max value of each row of weight (B) of gemm + out Output buffer + new_row_stats Not used for CPU/XPU + new_col_stats Not used for CPU/XPU + bias Bias of linear + compute_dtype Data type for computation + output_dtype Data type for output + Return: + The result + """ + assert A.dtype == torch.int32 + out_shape = A.shape + if len(out_shape) == 3: + out_shape = (out_shape[0] * out_shape[1], out_shape[2]) + + A_reshaped = A.reshape(out_shape).to(compute_dtype) + row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype) + col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype) + out = A_reshaped * row_stats * col_stats / (127 * 127) + if bias is not None: + out = out + bias.to(compute_dtype) + out = out.to(output_dtype) + return out diff --git a/bitsandbytes/backends/xpu.py b/bitsandbytes/backends/xpu.py deleted file mode 100644 index 9ee8a09dc..000000000 --- a/bitsandbytes/backends/xpu.py +++ /dev/null @@ -1,118 +0,0 @@ -# For Intel GPU (xpu is the device name for Intel GPU in PyTorch) -import torch -from .cpu import ( - double_quant_common, - igemmlt_common, - mm_dequant_common, -) - -Tensor = torch.Tensor - -def assert_on_xpu(tensors): - on_xpu = True - for t in tensors: - if t is None: continue # NULL pointers are fine - on_xpu &= (t.device.type == 'xpu') - if not on_xpu: - raise TypeError( - 'All input tensors 
need to be on XPU, but found some tensors to not be on XPU:\n' \ - f' {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}' - ) - return on_xpu - - -class XPUBackend: - mm_dequant_compute_dtype = torch.half - mm_dequant_output_dtype = torch.half - - @classmethod - @torch.compile(dynamic=True, options={"fx_graph_cache": True}) - def double_quant( - cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 - ): - assert_on_xpu([A, col_stats, row_stats, out_col, out_row]) - return double_quant_common(A, col_stats, row_stats, out_col, out_row) - - @classmethod - def transform(cls, A, to_order=None, from_order='row', out=None, transpose=False, state=None, ld=None): - """ - Transform tensor A to to_order. It is originally designed for CUDA. - For XPU, it returns the original tensor if transpose=False. - Otherwise, it returns the transpose of A - """ - assert_on_xpu([A, out]) - if transpose: - if out is not None: - out.copy_(A.T) - else: - out = A.T - else: - if out is not None: - out.copy_(A) - else: - out = A - return out, state - - @classmethod - def igemmlt(cls, A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32): - assert_on_xpu([A, B]) - return igemmlt_common(A, B, SA, SB, out, Sout, dtype) - - @classmethod - @torch.compile(dynamic=True, options={"fx_graph_cache": True}) - def mm_dequant( - cls, - A, - quant_state, - row_stats, - col_stats, - out=None, - new_row_stats=None, - new_col_stats=None, - bias=None - ): - assert_on_xpu([A, row_stats, col_stats, out, bias]) - return mm_dequant_common( - A, - quant_state, - row_stats, - col_stats, - out, - new_row_stats, - new_col_stats, - bias, - cls.mm_dequant_compute_dtype, - cls.mm_dequant_output_dtype - ) - - @classmethod - def extract_outliers(cls, A, SA, idx): - """ - Extract columns of A by idx - """ - assert_on_xpu([A]) - return A[:, idx].contiguous() - - @classmethod - def quantize_4bit( - cls, - A: Tensor, - absmax: Tensor = None, - out: Tensor = None, - blocksize=64, - compress_statistics=False, - quant_type="fp4", - ) -> Tensor: - assert False, "quantize_4bit not yet implemented for XPU backend" - - @classmethod - def dequantize_4bit( - cls, - A: Tensor, - quant_state = None, - absmax: Tensor = None, - out: Tensor = None, - blocksize: int = 64, - quant_type="fp4", - ) -> Tensor: - assert False, "dequantize_4bit not yet implemented for XPU backend" diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index baba76963..54a161f7a 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2177,7 +2177,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"): return xq, max1 elif quant_type in ["vector", "row"]: max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) - xq = torch.round(x * (C / max1)).to(torch.int8) + xq = torch.clamp(torch.round(x * (C / max1)), -128, 127).to(torch.int8) return xq, max1 elif quant_type == "zeropoint": dtype = x.dtype diff --git a/examples/int8_inference_huggingface_cpu.py b/examples/int8_inference_huggingface_cpu.py deleted file mode 100644 index b41605893..000000000 --- a/examples/int8_inference_huggingface_cpu.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig -import time - -MAX_NEW_TOKENS = 64 -model_id = "facebook/opt-1.3b" - -text = 'Hamburg is in which country?\n' -tokenizer = AutoTokenizer.from_pretrained(model_id) -input_ids = tokenizer(text, return_tensors="pt").input_ids - -print('Loading model {}...'.format(model_id)) 
-quantization_config = BitsAndBytesConfig(load_in_8bit=True) -model = AutoModelForCausalLM.from_pretrained( - model_id, - device_map='auto', - quantization_config=quantization_config, - torch_dtype=torch.bfloat16 -) -print('model dtype = {}'.format(model.dtype)) - -with torch.no_grad(): - t0 = time.time() - generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS) - latency = time.time() - t0 - result = "| latency: " + str(round(latency * 1000, 3)) + " ms |" - print('+' + '-' * (len(result) - 2) + '+') - print(result) - print('+' + '-' * (len(result) - 2) + '+') - -output = tokenizer.decode(generated_ids[0], skip_special_tokens=True) -print(f"output: {output}") diff --git a/tests/test_functional.py b/tests/test_functional.py index b9f1a6ead..cf4088c00 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -576,28 +576,37 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans @pytest.mark.parametrize("dim4", get_test_dims(32, 1024, n=1), ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) @pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) -def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): +@pytest.mark.parametrize("device", ("cuda", "cpu"), ids=id_formatter("device")) +def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb, device): for i in range(k): if dims == 2: - A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8) + A = torch.randint(-128, 127, size=(dim1, dim3), device=device).to(torch.int8) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8) - B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device=device).to(torch.int8) + B = torch.randint(-128, 127, size=(dim4, dim3), device=device).to(torch.int8) C1 = torch.matmul(A.float(), B.t().float()) A2, SA = F.transform(A, "col32") B2, SB = F.transform(B, "col_turing") C2, SC = F.igemmlt(A2, B2, SA, SB) - C3, S = F.nvidia_transform(C2, "row", state=SC) + if device == "cpu": + assert SC is None + if device == "cuda": + C3, S = F.nvidia_transform(C2, "row", state=SC) + else: + C3, S = C2, None torch.testing.assert_close(C1, C3.float()) # transpose - B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=(dim3, dim4), device=device).to(torch.int8) C1 = torch.matmul(A.float(), B.float()) B2t, SBt = F.transform(B, "col_turing", transpose=True) C2, SC = F.igemmlt(A2, B2t, SA, SBt) - C3, S = F.nvidia_transform(C2, "row", state=SC) + if device == "cuda": + C3, S = F.nvidia_transform(C2, "row", state=SC) + else: + C3, S = C2, None torch.testing.assert_close(C1, C3.float()) @@ -846,6 +855,33 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n)) +@pytest.mark.parametrize("dim1", get_test_dims(64, 256, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim4", get_test_dims(64, 1024, n=2), ids=id_formatter("dim4")) +@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) +@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) +def test_dequant_mm_cpu(dim1, dim4, dims, has_bias): + inner = torch.randint(1, 128, size=(1,)).item() + bias = None + if has_bias: + bias = torch.randn(dim4, device="cpu", dtype=torch.bfloat16) + for i in range(1): + A = torch.randn(dim1, inner, device="cpu") + B = 
torch.randn(dim4, inner, device="cpu") + + A1, maxA = F.vectorwise_quant(A, dim=1) + B1, maxB = F.vectorwise_quant(B, dim=1) + + C2, SC = F.igemmlt(A1, B1, SA=None, SB=None) + assert SC is None + + C3 = F.vectorwise_mm_dequant(C2.bfloat16(), maxA, maxB.t()) + if has_bias: + C3 += bias + + C4 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias) + torch.testing.assert_close(C3.float(), C4.float(), atol=0.05, rtol=0.1) + + @pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) @@ -892,9 +928,13 @@ def test_colrow_absmax(dim1, dim2, dims): @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) -def test_double_quant(dim1, dim2): +@pytest.mark.parametrize("device", ["cuda","cpu"], ids=id_formatter("device")) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16], ids=id_formatter("dtype")) +def test_double_quant(dim1, dim2, device, dtype): + if device == "cuda" and dtype == torch.bfloat16: + pytest.skip("BFloat16 not supported on CUDA") for i in range(k): - A = torch.randn(dim1, dim2, device="cuda").half() + A = torch.randn(dim1, dim2, device=device).to(dtype) out_col1, Scol = F.vectorwise_quant(A, dim=0) out_row1, Srow = F.vectorwise_quant(A, dim=1) @@ -1125,6 +1165,33 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): torch.testing.assert_close(out1, out2) +@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims")) +@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype) +@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) +@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut")) +@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) +def test_transform_cpu(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): + for i in range(k): + if dims == 2: + A = torch.randint(10, 99, size=(dim1, dim2), device="cpu").to(dtype) + elif dims == 3: + A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cpu").to(dtype) + + A.view(-1)[-1] = -1 + if transpose: + out1 = A.t().contiguous() + else: + out1 = A + out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose) + + assert S2 is None + + torch.testing.assert_close(out1, out2) + + def test_overflow(): formatB = F.get_special_format_str() print(formatB) @@ -1141,10 +1208,14 @@ def test_overflow(): @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) -def test_coo_double_quant(dim1, dim2): +@pytest.mark.parametrize("device", ["cuda","cpu"], ids=id_formatter("device")) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16], ids=id_formatter("dtype")) +def test_coo_double_quant(dim1, dim2, device, dtype): + if device == "cuda" and dtype == torch.bfloat16: + pytest.skip("BFloat16 not supported on CUDA") threshold = 3.00 for i in range(k): - A = torch.randn(dim1, dim2, device="cuda").half() + A = torch.randn(dim1, dim2, 
device=device).to(dtype) idx = torch.abs(A) >= threshold CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) @@ -1157,7 +1228,7 @@ def test_coo_double_quant(dim1, dim2): torch.testing.assert_close(A1, A2) A1 = A * (idx == 0) - A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() + A2 = (CA.float() * statsA.unsqueeze(1) / 127).to(dtype) torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) @@ -1729,12 +1800,12 @@ def quant_zp(x): print(err1, err2, err3, err4, err5, err6) -def test_extract_outliers(): +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +def test_extract_outliers(device): for i in range(k): shapeA = (4096, 4096 * 4) - idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda() - # idx = torch.Tensor([0]).int().cuda() - A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) + idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).to(device=device) + A = torch.randint(-128, 127, size=shapeA, device=device).to(torch.int8) outliers1 = A[:, idx.long()] CA, SA = F.transform(A, "col_turing") From 8d0b695d8aadaa225b39352068a8dc7d999c4eae Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 15 Apr 2024 02:40:40 -0700 Subject: [PATCH 03/12] Fix igemmlt correctness issue --- bitsandbytes/backends/cpu_xpu_common.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index e6bc59075..5be83d8e3 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -125,13 +125,9 @@ def igemmlt_impl( m = shapeA[0] elif dimsA == 3: m = shapeA[0] * shapeA[1] - if shapeA[-1] == shapeB[0]: - B = B.t() - shapeB = B.shape - else: - assert shapeA[-1] == shapeB[-1], f'Shapes of A and B do not match, got {shapeA} and {shapeB}' n = shapeB[0] k = shapeA[-1] + assert shapeA[-1] == shapeB[-1], f'Shapes of A and B do not match, got {shapeA} and {shapeB}' # if the tensor is empty, return a transformed empty tensor with the right dimensions if shapeA[0] == 0 and dimsA == 2: From 67d86611d5b4e34f5d8e8ebc1c1e08dddee671ae Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 17 Apr 2024 23:06:57 -0700 Subject: [PATCH 04/12] Bug fix for double_quant --- bitsandbytes/backends/cpu.py | 2 +- bitsandbytes/backends/cpu_xpu_common.py | 11 +++++++++-- bitsandbytes/nn/modules.py | 19 ------------------- tests/test_functional.py | 4 +++- 4 files changed, 13 insertions(+), 23 deletions(-) diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py index 82c411166..5183e6485 100644 --- a/bitsandbytes/backends/cpu.py +++ b/bitsandbytes/backends/cpu.py @@ -31,7 +31,7 @@ def double_quant( cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 ): assert_on_cpu([A, col_stats, row_stats, out_col, out_row]) - return double_quant_impl(A, col_stats, row_stats, out_col, out_row) + return double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold) @classmethod def transform(cls, A, to_order=None, from_order='row', out=None, transpose=False, state=None, ld=None): diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index 5be83d8e3..7c7927c88 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -1,4 +1,5 @@ import torch +import warnings Tensor = torch.Tensor @@ -66,7 +67,7 @@ def quant_to_int8(A, stats): if row_stats is None or col_stats is None: row_stats, col_stats = get_row_col_stats(A) else: 
- outlier_indices = torch.abs(A) > threshold # find outliers + outlier_indices = torch.abs(A) >= threshold # find outliers outlier_coord = outlier_indices.nonzero() # get outlier coordinates outlier_rows = outlier_coord[:, 0] # outlier row for COO sparse tensor outlier_cols = outlier_coord[:, 1] # outlier column for COO sparse tensor @@ -77,10 +78,13 @@ def quant_to_int8(A, stats): if row_stats is None or col_stats is None: A[outlier_indices] = 0 # zero out outliers row_stats, col_stats = get_row_col_stats(A) - A[outlier_indices] = outlier_values # restore outliers for later use quant_by_row = quant_to_int8(A, row_stats.unsqueeze(-1)) quant_by_col = quant_to_int8(A, col_stats.unsqueeze(0)) + + if coo_tensor is not None: + A[outlier_indices] = outlier_values # restore outliers for later use + if out_row is not None: out_row.copy_(quant_by_row) else: @@ -189,6 +193,9 @@ def mm_dequant_impl( if len(out_shape) == 3: out_shape = (out_shape[0] * out_shape[1], out_shape[2]) + if compute_dtype not in [torch.float32, torch.bfloat16]: + warnings.warn(f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use float instead") + compute_dtype = torch.float32 A_reshaped = A.reshape(out_shape).to(compute_dtype) row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype) col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index bcba8b3d2..c0efafec0 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -594,19 +594,6 @@ def cpu(self): setattr(self, "SCB", SCB) return self - def xpu(self): - # we store the 8-bit rows-major weight - B = self.data.contiguous().half().cpu() - CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) - if CBt is not None: - del CBt - if SCBt is not None: - del SCBt - self.data = CB - setattr(self, "CB", CB) - setattr(self, "SCB", SCB) - return self - @overload def to( self: T, @@ -626,12 +613,6 @@ def to(self, *args, **kwargs): if device is not None and device.type == "cuda" and self.data.device.type == "cpu": return self.cuda(device) - elif ( - device is not None - and device.type == "xpu" - and self.data.dtype != torch.int8 - ): - return self.xpu() elif ( device is not None and device.type == "cpu" diff --git a/tests/test_functional.py b/tests/test_functional.py index cf4088c00..ba1e32d77 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1221,6 +1221,8 @@ def test_coo_double_quant(dim1, dim2, device, dtype): CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) + if idx.sum() > 0: + assert coo_tensor is not None if coo_tensor is not None: A1 = A * idx A2 = torch.zeros_like(A) @@ -1229,7 +1231,7 @@ def test_coo_double_quant(dim1, dim2, device, dtype): A1 = A * (idx == 0) A2 = (CA.float() * statsA.unsqueeze(1) / 127).to(dtype) - torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) + torch.testing.assert_close(A1, A2, rtol=0.05, atol=1.5e-2) @pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1")) From 92900f6cc82d0010909aa6eadc13bdc497fb36f9 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 18 Apr 2024 04:24:23 -0700 Subject: [PATCH 05/12] Remove torch.compile for double_quant --- bitsandbytes/__init__.py | 1 - bitsandbytes/backends/cpu_xpu_common.py | 2 +- bitsandbytes/nn/modules.py | 5 ++--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/bitsandbytes/__init__.py 
b/bitsandbytes/__init__.py index 48144a870..cc7812e4e 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -25,7 +25,6 @@ from .optim import adam register_backend("cuda", CUDABackend()) - __pdoc__ = { "libbitsandbytes": False, "optim.optimizer.Optimizer8bit": False, diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index 7c7927c88..c6573b6c0 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -22,7 +22,7 @@ def _maybe_torch_compile(func): return func -@_maybe_torch_compile +# Don't use torch.compile for now due to PyTorch issue https://github.com/pytorch/pytorch/issues/124382 def double_quant_impl( A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 ): diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index c0efafec0..9295e4c70 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -611,11 +611,10 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device is not None and device.type == "cuda" and self.data.device.type == "cpu": + if device.type == "cuda" and self.data.device.type == "cpu": return self.cuda(device) elif ( - device is not None - and device.type == "cpu" + device.type == "cpu" and self.data.dtype != torch.int8 ): return self.cpu() From 717245d4f377484de5bf67c22c58ac13fc2d02cc Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 18 Apr 2024 19:56:32 -0700 Subject: [PATCH 06/12] refine pytest.skip message --- tests/test_functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index ba1e32d77..566d8429f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -932,7 +932,7 @@ def test_colrow_absmax(dim1, dim2, dims): @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16], ids=id_formatter("dtype")) def test_double_quant(dim1, dim2, device, dtype): if device == "cuda" and dtype == torch.bfloat16: - pytest.skip("BFloat16 not supported on CUDA") + pytest.skip("bfloat16 is not implemented for this operation on CUDA backend") for i in range(k): A = torch.randn(dim1, dim2, device=device).to(dtype) out_col1, Scol = F.vectorwise_quant(A, dim=0) @@ -1212,7 +1212,7 @@ def test_overflow(): @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16], ids=id_formatter("dtype")) def test_coo_double_quant(dim1, dim2, device, dtype): if device == "cuda" and dtype == torch.bfloat16: - pytest.skip("BFloat16 not supported on CUDA") + pytest.skip("bfloat16 is not implemented for this operation on CUDA backend") threshold = 3.00 for i in range(k): A = torch.randn(dim1, dim2, device=device).to(dtype) From 93e04b5cfa56e206a9699a60a4c35972f23e69b6 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 24 Apr 2024 18:37:02 -0700 Subject: [PATCH 07/12] Fix lint issues --- bitsandbytes/__init__.py | 5 +-- bitsandbytes/autograd/_functions.py | 4 +-- bitsandbytes/backends/cpu.py | 31 +++++++----------- bitsandbytes/backends/cpu_xpu_common.py | 43 ++++++++++++------------- bitsandbytes/functional.py | 2 +- bitsandbytes/nn/modules.py | 9 ++---- tests/test_functional.py | 4 +-- 7 files changed, 42 insertions(+), 56 deletions(-) diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index cc7812e4e..dc9094a2c 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py 
@@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import torch + from . import research, utils from .autograd._functions import ( MatmulLtState, @@ -13,11 +14,11 @@ matmul_cublas, mm_cublas, ) +from .backends import register_backend +from .backends.cpu import CPUBackend from .cextension import lib from .nn import modules -from .backends import register_backend -from .backends.cpu import CPUBackend register_backend("cpu", CPUBackend) if lib and lib.compiled_with_cuda: diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 67b8b6b87..08d2d9fa6 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -217,7 +217,7 @@ def backward(ctx, grad_output): def supports_igemmlt(device: torch.device) -> bool: """check if this device supports the optimized int8 kernel""" - if device == torch.device('cpu'): + if device == torch.device("cpu"): return True if torch.cuda.get_device_capability(device=device) < (7, 5): return False @@ -315,7 +315,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): # Cast A to fp16 A_dtype = torch.float16 - if A.device == torch.device('cpu'): + if A.device == torch.device("cpu"): A_dtype = torch.bfloat16 if A.dtype != A_dtype: warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {A_dtype} during quantization") diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py index 5183e6485..fe77005a0 100644 --- a/bitsandbytes/backends/cpu.py +++ b/bitsandbytes/backends/cpu.py @@ -1,23 +1,24 @@ import torch + from .cpu_xpu_common import ( double_quant_impl, igemmlt_impl, mm_dequant_impl, ) - Tensor = torch.Tensor def assert_on_cpu(tensors): on_cpu = True for t in tensors: - if t is None: continue # NULL pointers are fine - on_cpu &= (t.device.type == 'cpu') + if t is None: + continue # NULL pointers are fine + on_cpu &= t.device.type == "cpu" if not on_cpu: raise TypeError( - 'All input tensors need to be on CPU, but found some tensors to not be on CPU:\n' \ - f' {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}' + "All input tensors need to be on CPU, but found some tensors to not be on CPU:\n" + f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}" ) return on_cpu @@ -27,14 +28,12 @@ class CPUBackend: mm_dequant_output_dtype = torch.bfloat16 @classmethod - def double_quant( - cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 - ): + def double_quant(cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): assert_on_cpu([A, col_stats, row_stats, out_col, out_row]) return double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold) @classmethod - def transform(cls, A, to_order=None, from_order='row', out=None, transpose=False, state=None, ld=None): + def transform(cls, A, to_order=None, from_order="row", out=None, transpose=False, state=None, ld=None): """ Transform tensor A to to_order. It is originally designed for CUDA. For CPU, it returns the original tensor if transpose=False. 
@@ -60,15 +59,7 @@ def igemmlt(cls, A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32) @classmethod def mm_dequant( - cls, - A, - quant_state, - row_stats, - col_stats, - out=None, - new_row_stats=None, - new_col_stats=None, - bias=None + cls, A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None ): assert_on_cpu([A, row_stats, col_stats, out, bias]) return mm_dequant_impl( @@ -81,7 +72,7 @@ def mm_dequant( new_col_stats, bias, cls.mm_dequant_compute_dtype, - cls.mm_dequant_output_dtype + cls.mm_dequant_output_dtype, ) @classmethod @@ -108,7 +99,7 @@ def quantize_4bit( def dequantize_4bit( cls, A: Tensor, - quant_state = None, + quant_state=None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index c6573b6c0..c4bd25a04 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -1,13 +1,13 @@ -import torch import warnings +import torch Tensor = torch.Tensor def _torch_version_prereq(major, minor): - ver_major = int(torch.__version__.split('.')[0]) - ver_minor = int(torch.__version__.split('.')[1]) + ver_major = int(torch.__version__.split(".")[0]) + ver_minor = int(torch.__version__.split(".")[1]) return ver_major * 32 + ver_minor >= major * 32 + minor @@ -23,14 +23,12 @@ def _maybe_torch_compile(func): # Don't use torch.compile for now due to PyTorch issue https://github.com/pytorch/pytorch/issues/124382 -def double_quant_impl( - A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 -): +def double_quant_impl(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): """ Find absolute max valus of each row/column of a tensor, and symmetrically quantize it to int8. If threshold > 0.0, only values <= threshold are counted. All outliers are zeroed out in the original tensor and they are kept in COO format: (rows, cols, valus) - If threashold == 0.0, there are no outliers. + If threshold == 0.0, there are no outliers. Args: A The tensor to be analyzed and quantized. col_stats Absolute max values of each column of A. If it is not None, use the values directly. 
@@ -45,6 +43,7 @@ def double_quant_impl( each row of A, absolute max values of each column of A, outliers in COO format """ from ..functional import COOSparseTensor + cols = A.shape[-1] if len(A.shape) == 3: rows = A.shape[0] * A.shape[1] @@ -56,8 +55,8 @@ def double_quant_impl( coo_tensor = None def get_row_col_stats(A): - row_stats = torch.max(torch.abs(A), 1).values # absolute max of each row - col_stats = torch.max(torch.abs(A), 0).values # absolute max of each col + row_stats = torch.max(torch.abs(A), 1).values # absolute max of each row + col_stats = torch.max(torch.abs(A), 0).values # absolute max of each col return row_stats, col_stats def quant_to_int8(A, stats): @@ -67,23 +66,23 @@ def quant_to_int8(A, stats): if row_stats is None or col_stats is None: row_stats, col_stats = get_row_col_stats(A) else: - outlier_indices = torch.abs(A) >= threshold # find outliers - outlier_coord = outlier_indices.nonzero() # get outlier coordinates - outlier_rows = outlier_coord[:, 0] # outlier row for COO sparse tensor - outlier_cols = outlier_coord[:, 1] # outlier column for COO sparse tensor - outlier_values = A[outlier_indices] # outlier values for COO sparse tensor + outlier_indices = torch.abs(A) >= threshold # find outliers + outlier_coord = outlier_indices.nonzero() # get outlier coordinates + outlier_rows = outlier_coord[:, 0] # outlier row for COO sparse tensor + outlier_cols = outlier_coord[:, 1] # outlier column for COO sparse tensor + outlier_values = A[outlier_indices] # outlier values for COO sparse tensor coo_tensor = COOSparseTensor( A.shape[0], A.shape[1], outlier_values.numel(), outlier_rows.int(), outlier_cols.int(), outlier_values ) if row_stats is None or col_stats is None: - A[outlier_indices] = 0 # zero out outliers + A[outlier_indices] = 0 # zero out outliers row_stats, col_stats = get_row_col_stats(A) quant_by_row = quant_to_int8(A, row_stats.unsqueeze(-1)) quant_by_col = quant_to_int8(A, col_stats.unsqueeze(0)) if coo_tensor is not None: - A[outlier_indices] = outlier_values # restore outliers for later use + A[outlier_indices] = outlier_values # restore outliers for later use if out_row is not None: out_row.copy_(quant_by_row) @@ -97,9 +96,7 @@ def quant_to_int8(A, stats): return out_row, out_col, row_stats.float(), col_stats.float(), coo_tensor -def igemmlt_impl( - A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32 -): +def igemmlt_impl(A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32): """ Do GEMMM computation. Data type: int8 * int8 -> int32. 
Args: @@ -122,8 +119,8 @@ def igemmlt_impl( dimsB = B.ndim shapeA = A.shape shapeB = B.shape - assert dimsA in [2, 3], 'Only two or three dimensional matrices are supported for argument A' - assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' + assert dimsA in [2, 3], "Only two or three dimensional matrices are supported for argument A" + assert dimsB == 2, "Only two dimensional matrices are supported for argument B" if dimsA == 2: m = shapeA[0] @@ -131,7 +128,7 @@ def igemmlt_impl( m = shapeA[0] * shapeA[1] n = shapeB[0] k = shapeA[-1] - assert shapeA[-1] == shapeB[-1], f'Shapes of A and B do not match, got {shapeA} and {shapeB}' + assert shapeA[-1] == shapeB[-1], f"Shapes of A and B do not match, got {shapeA} and {shapeB}" # if the tensor is empty, return a transformed empty tensor with the right dimensions if shapeA[0] == 0 and dimsA == 2: @@ -169,7 +166,7 @@ def mm_dequant_impl( new_col_stats=None, bias=None, compute_dtype=torch.float32, - output_dtype=torch.float32 + output_dtype=torch.float32, ): """ Dequant and add bias diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 54a161f7a..9828add30 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1945,7 +1945,7 @@ class COOSparseTensor: def __init__(self, rows, cols, nnz, rowidx, colidx, values): assert rowidx.dtype == torch.int32 assert colidx.dtype == torch.int32 - if values.device == torch.device('cpu'): + if values.device == torch.device("cpu"): assert values.dtype in [torch.bfloat16, torch.float] else: assert values.dtype == torch.float16 diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 9295e4c70..f2b5e34b8 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -590,8 +590,8 @@ def cpu(self): if SCBt is not None: del SCBt self.data = CB - setattr(self, "CB", CB) - setattr(self, "SCB", SCB) + self.CB = CB + self.SCB = SCB return self @overload @@ -613,10 +613,7 @@ def to(self, *args, **kwargs): if device.type == "cuda" and self.data.device.type == "cpu": return self.cuda(device) - elif ( - device.type == "cpu" - and self.data.dtype != torch.int8 - ): + elif device.type == "cpu" and self.data.dtype != torch.int8: return self.cpu() else: new_param = Int8Params( diff --git a/tests/test_functional.py b/tests/test_functional.py index 566d8429f..94b4222c2 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -928,7 +928,7 @@ def test_colrow_absmax(dim1, dim2, dims): @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) -@pytest.mark.parametrize("device", ["cuda","cpu"], ids=id_formatter("device")) +@pytest.mark.parametrize("device", ["cuda", "cpu"], ids=id_formatter("device")) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16], ids=id_formatter("dtype")) def test_double_quant(dim1, dim2, device, dtype): if device == "cuda" and dtype == torch.bfloat16: @@ -1208,7 +1208,7 @@ def test_overflow(): @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) -@pytest.mark.parametrize("device", ["cuda","cpu"], ids=id_formatter("device")) +@pytest.mark.parametrize("device", ["cuda", "cpu"], ids=id_formatter("device")) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16], ids=id_formatter("dtype")) def test_coo_double_quant(dim1, dim2, 
device, dtype): if device == "cuda" and dtype == torch.bfloat16: From e1b60d3093759bca952b9aefea4ba1140c1e2340 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 25 Apr 2024 23:51:57 -0700 Subject: [PATCH 08/12] Fix backward --- bitsandbytes/autograd/_functions.py | 3 +++ bitsandbytes/functional.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 08d2d9fa6..7d570f28b 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -94,6 +94,9 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) - :param tile_indices: reverse transformation indices, from get_inverse_transform_indices :return: contiguous row-major tensor """ + # CPU has no change on layout + if permuted_tensor.device.type == "cpu": + return permuted_tensor (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles" tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t() diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 9828add30..8fd62fd04 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1946,7 +1946,7 @@ def __init__(self, rows, cols, nnz, rowidx, colidx, values): assert rowidx.dtype == torch.int32 assert colidx.dtype == torch.int32 if values.device == torch.device("cpu"): - assert values.dtype in [torch.bfloat16, torch.float] + assert values.dtype in [torch.bfloat16, torch.half, torch.float] else: assert values.dtype == torch.float16 assert values.numel() == nnz From 95c29a63ba04be0ce48bc2031861753f4e8215c4 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Sun, 5 May 2024 19:42:02 -0700 Subject: [PATCH 09/12] Fix lint issue --- bitsandbytes/backends/cpu_xpu_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index c4bd25a04..815349e46 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -25,9 +25,9 @@ def _maybe_torch_compile(func): # Don't use torch.compile for now due to PyTorch issue https://github.com/pytorch/pytorch/issues/124382 def double_quant_impl(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): """ - Find absolute max valus of each row/column of a tensor, and symmetrically quantize it to int8. + Find absolute max values of each row/column of a tensor, and symmetrically quantize it to int8. If threshold > 0.0, only values <= threshold are counted. All outliers are zeroed out in - the original tensor and they are kept in COO format: (rows, cols, valus) + the original tensor and they are kept in COO format: (rows, cols, values) If threshold == 0.0, there are no outliers. Args: A The tensor to be analyzed and quantized. 
From b0dec0a55c3464ed97a8cadeb1ba0d43f2704c25 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 7 May 2024 00:48:29 -0700 Subject: [PATCH 10/12] Update bitsandbytes/backends/cpu_xpu_common.py --- bitsandbytes/backends/cpu_xpu_common.py | 26 +++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index 815349e46..5a0f0f9d5 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -2,6 +2,16 @@ import torch +try: + # to support Intel CPU/GPU (XPU) backend + import intel_extension_for_pytorch as ipex + ipex_cpu = ipex if ipex._C._has_cpu() else None + ipex_xpu = ipex if ipex._C._has_xpu() else None +except: + ipex_cpu = None + ipex_xpu = None + + Tensor = torch.Tensor @@ -11,6 +21,22 @@ def _torch_version_prereq(major, minor): return ver_major * 32 + ver_minor >= major * 32 + minor +def _ipex_cpu_version_prereq(major, minor): + if ipex_cpu is not None: + ver_major = ipex_cpu.__version__.split(".")[0] + ver_minor = ipex_cpu.__version__.split(".")[1] + return int(ver_major) * 32 + int(ver_minor) >= major * 32 + minor + return False + + +def _ipex_xpu_version_prereq(major, minor): + if ipex_xpu is not None: + ver_major = ipex_xpu.__version__.split(".")[0] + ver_minor = ipex_xpu.__version__.split(".")[1] + return int(ver_major) * 32 + int(ver_minor) >= major * 32 + minor + return False + + def _maybe_torch_compile(func): # torch.compile requires pytorch >= 2.0 if _torch_version_prereq(2, 0): From 295bb973c301bfbb5d51aed5a2b79e955840296b Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 7 May 2024 02:40:01 -0700 Subject: [PATCH 11/12] Fix lint issue --- bitsandbytes/backends/cpu.py | 1 - bitsandbytes/backends/cpu_xpu_common.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py index 97e6580ed..d6a9192e4 100644 --- a/bitsandbytes/backends/cpu.py +++ b/bitsandbytes/backends/cpu.py @@ -5,7 +5,6 @@ from bitsandbytes.utils import QuantState from .base import Backend - from .cpu_xpu_common import ( double_quant_impl, igemmlt_impl, diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index 5a0f0f9d5..ceac893b4 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -5,6 +5,7 @@ try: # to support Intel CPU/GPU (XPU) backend import intel_extension_for_pytorch as ipex + ipex_cpu = ipex if ipex._C._has_cpu() else None ipex_xpu = ipex if ipex._C._has_xpu() else None except: From 37b05821a5decab33c67527b9e365cbf0fbdf2f0 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 7 May 2024 06:34:32 -0700 Subject: [PATCH 12/12] Fix lint issue --- bitsandbytes/backends/cpu_xpu_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index ceac893b4..f4e5ed3ec 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -8,7 +8,7 @@ ipex_cpu = ipex if ipex._C._has_cpu() else None ipex_xpu = ipex if ipex._C._has_xpu() else None -except: +except BaseException: ipex_cpu = None ipex_xpu = None
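
Background note for readers unfamiliar with the int8 scheme these patches implement: double_quant_impl performs symmetric row-wise and column-wise int8 quantization and, when threshold > 0.0, routes values whose absolute magnitude reaches the threshold into a COO sparse triple instead of the quantized tensors. The standalone sketch below illustrates that scheme in plain PyTorch. It is a simplified approximation for illustration only, not part of any patch above and not the bitsandbytes implementation; the name simple_double_quant is invented here and does not exist in the library.

# Illustrative sketch only -- simplified double-quant with outlier extraction.
import torch


def simple_double_quant(A: torch.Tensor, threshold: float = 0.0):
    """Row-/column-wise symmetric int8 quantization with optional outlier
    extraction, loosely following the behaviour described in the
    double_quant_impl docstring. Not the real implementation."""
    A = A.float().clone()
    outliers = None
    if threshold > 0.0:
        mask = A.abs() >= threshold                     # locate outliers
        coords = mask.nonzero()                         # (row, col) pairs
        outliers = (coords[:, 0], coords[:, 1], A[mask])  # COO-style (rows, cols, values)
        A[mask] = 0.0                                   # exclude outliers from the stats
    row_stats = A.abs().amax(dim=1)                     # absolute max of each row
    col_stats = A.abs().amax(dim=0)                     # absolute max of each column
    # symmetric quantization: q = round(x * 127 / absmax)
    q_row = torch.round(A * (127.0 / row_stats.clamp(min=1e-12)).unsqueeze(1)).to(torch.int8)
    q_col = torch.round(A * (127.0 / col_stats.clamp(min=1e-12)).unsqueeze(0)).to(torch.int8)
    return q_row, q_col, row_stats, col_stats, outliers


if __name__ == "__main__":
    # bfloat16 input, matching the CPU compute dtype used in these patches
    x = torch.randn(4, 8).to(torch.bfloat16)
    q_row, q_col, rs, cs, coo = simple_double_quant(x, threshold=3.0)
    print(q_row.dtype, q_row.shape, q_col.shape, coo is None)

The real code keeps the outliers zeroed only while the absmax statistics are computed and restores them afterwards, and it reuses caller-provided stats and output buffers when given; the sketch omits those details to keep the core idea visible.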