diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 019a4f6ab..0dae37e8d 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import torch
 from . import research, utils
 from .autograd._functions import (
     MatmulLtState,
@@ -14,13 +15,22 @@
 )
 from .cextension import lib
 from .nn import modules
+from .backends import register_backend
 
 if lib and lib.compiled_with_cuda:
-    from .backends import register_backend
     from .backends.cuda import CUDABackend
     from .optim import adam
 
     register_backend("cuda", CUDABackend())
+
+elif torch.xpu.is_available():
+    from .backends.xpu import XPUBackend
+    register_backend("xpu", XPUBackend)
+
+else:
+    from .backends.cpu import CPUBackend
+    register_backend("cpu", CPUBackend)
+
 __pdoc__ = {
     "libbitsandbytes": False,
     "optim.optimizer.Optimizer8bit": False,
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index e9821cd36..67b8b6b87 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -217,6 +217,8 @@ def backward(ctx, grad_output):
 
 def supports_igemmlt(device: torch.device) -> bool:
     """check if this device supports the optimized int8 kernel"""
+    if device == torch.device('cpu'):
+        return True
     if torch.cuda.get_device_capability(device=device) < (7, 5):
         return False
     device_name = torch.cuda.get_device_name(device=device)
@@ -312,13 +314,16 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
             state.outlier_pool = GlobalOutlierPooler.get_instance()
 
         # Cast A to fp16
-        if A.dtype != torch.float16:
-            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
+        A_dtype = torch.float16
+        if A.device == torch.device('cpu'):
+            A_dtype = torch.bfloat16
+        if A.dtype != A_dtype:
+            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {A_dtype} during quantization")
 
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.reshape(-1, A.shape[-1])
-        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
+        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(A_dtype), threshold=state.threshold)
 
         if state.threshold > 0.0 and coo_tensorA is not None:
             if state.has_fp16_weights:
@@ -393,7 +398,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
         if using_igemmlt:
             C32A, SA = F.transform(CA, "col32")
             out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
-            if bias is None or bias.dtype == torch.float16:
+            if bias is None or bias.dtype == A_dtype:
                 # we apply the fused bias here
                 output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
                 output = output.to(A.dtype)
diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py
new file mode 100644
index 000000000..31bb52945
--- /dev/null
+++ b/bitsandbytes/backends/cpu.py
@@ -0,0 +1,287 @@
+import torch
+
+
+Tensor = torch.Tensor
+
+
+def assert_on_cpu(tensors):
+    on_cpu = True
+    for t in tensors:
+        if t is None: continue  # NULL pointers are fine
+        on_cpu &= (t.device.type == 'cpu')
+    if not on_cpu:
+        raise TypeError(
+            'All input tensors need to be on CPU, but found some tensors to not be on CPU:\n' \
+            f' {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}'
+        )
+    return on_cpu
+
+
+@torch.compile(dynamic=True, options={"fx_graph_cache": True})
+def double_quant_common(
+    A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
+):
+    """
+    Find the absolute max values of each row/column of a tensor, and symmetrically quantize it to int8.
+    If threshold > 0.0, only values whose absolute value is <= threshold are counted in the statistics.
+    All outliers are zeroed out in the original tensor and kept in COO format: (rows, cols, values).
+    If threshold == 0.0, there are no outliers.
+    Args:
+        A           The tensor to be analyzed and quantized.
+        col_stats   Absolute max values of each column of A. If it is not None, use the values directly.
+                    Otherwise, find the values.
+        row_stats   Absolute max values of each row of A. If it is not None, use the values directly.
+                    Otherwise, find the values.
+        out_col     Output buffer for the result quantized per column if it is not None
+        out_row     Output buffer for the result quantized per row if it is not None
+        threshold   The threshold for finding outliers if it is > 0.0. Otherwise it has no effect.
+    Return:
+        A tuple of output quantized per row, output quantized per column, absolute max values of
+        each row of A, absolute max values of each column of A, outliers in COO format
+    """
+    from ..functional import COOSparseTensor
+    cols = A.shape[-1]
+    if len(A.shape) == 3:
+        rows = A.shape[0] * A.shape[1]
+    else:
+        assert A.dim() == 2, f"double_quant: Input tensor should be 2d or 3d but got {A.dim()}d"
+        rows = A.shape[0]
+    A = A.reshape(rows, cols)
+
+    coo_tensor = None
+
+    def get_row_col_stats(A):
+        row_stats = torch.max(torch.abs(A), 1).values  # absolute max of each row
+        col_stats = torch.max(torch.abs(A), 0).values  # absolute max of each col
+        return row_stats, col_stats
+
+    def quant_to_int8(A, stats):
+        return torch.clamp(torch.round(A / stats * 127).to(torch.int8), -128, 127)
+
+    if threshold == 0.0:
+        if row_stats is None or col_stats is None:
+            row_stats, col_stats = get_row_col_stats(A)
+    else:
+        outlier_indices = torch.abs(A) > threshold  # find outliers
+        outlier_coord = outlier_indices.nonzero()  # get outlier coordinates
+        outlier_rows = outlier_coord[:, 0]  # outlier row for COO sparse tensor
+        outlier_cols = outlier_coord[:, 1]  # outlier column for COO sparse tensor
+        outlier_values = A[outlier_indices]  # outlier values for COO sparse tensor
+        coo_tensor = COOSparseTensor(
+            A.shape[0], A.shape[1], outlier_values.numel(), outlier_rows.int(), outlier_cols.int(), outlier_values
+        )
+        if row_stats is None or col_stats is None:
+            A[outlier_indices] = 0  # zero out outliers
+            row_stats, col_stats = get_row_col_stats(A)
+            A[outlier_indices] = outlier_values  # restore outliers for later use
+
+    quant_by_row = quant_to_int8(A, row_stats.unsqueeze(-1))
+    quant_by_col = quant_to_int8(A, col_stats.unsqueeze(0))
+    if out_row is not None:
+        out_row.copy_(quant_by_row)
+    else:
+        out_row = quant_by_row
+    if out_col is not None:
+        out_col.copy_(quant_by_col)
+    else:
+        out_col = quant_by_col
+    return out_row, out_col, row_stats, col_stats, coo_tensor
+
+
+def igemmlt_common(
+    A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32
+):
+    """
+    Do GEMM computation. Data type: int8 * int8 -> int32.
+    Args:
+        A       Activation of linear, data type is int8
+        B       Weight of linear, data type is int8
+        SA      Not used for CPU/XPU
+        SB      Not used for CPU/XPU
+        out     Specified output tensor if it is not None
+        Sout    Not used for CPU/XPU but returned as is
+        dtype   Data type of output
+    Return:
+        A tuple of GEMM result in dtype and Sout
+    """
+    assert A.dtype == torch.int8
+    assert B.dtype == torch.int8
+    if out is not None:
+        assert out.dtype == dtype
+
+    dimsA = A.ndim
+    dimsB = B.ndim
+    shapeA = A.shape
+    shapeB = B.shape
+    assert dimsA in [2, 3], 'Only two or three dimensional matrices are supported for argument A'
+    assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
+
+    if dimsA == 2:
+        m = shapeA[0]
+    elif dimsA == 3:
+        m = shapeA[0] * shapeA[1]
+    n = shapeB[0]
+    k = shapeA[-1]
+    assert shapeA[-1] == shapeB[-1], f'Shapes of A and B do not match, got {shapeA} and {shapeB}'
+    shapeOut = (shapeA[0], shapeA[1], n) if dimsA == 3 else (m, n)
+
+    # if the tensor is empty, return a transformed empty tensor with the right dimensions
+    if shapeA[0] == 0 and dimsA == 2:
+        return torch.empty((0, n), device=A.device, dtype=A.dtype)
+    elif shapeA[1] == 0 and dimsA == 3:
+        return torch.empty(tuple(shapeA[:2] + [n]), device=A.device, dtype=A.dtype)
+
+    A_reshaped = A.reshape(m, k)
+
+    if assert_on_cpu([A_reshaped, B]):
+        C = torch._int_mm(A_reshaped, B.T).to(dtype)
+    else:
+        C = torch.nn.functional.linear(A_reshaped, B).to(dtype)
+    if C.ndim != dimsA:
+        C = C.reshape(shapeOut)
+    if out is not None:
+        out.copy_(C)
+    else:
+        out = C
+
+    return out, Sout
+
+
+@torch.compile(dynamic=True, options={"fx_graph_cache": True})
+def mm_dequant_common(
+    A,
+    quant_state,
+    row_stats,
+    col_stats,
+    out=None,
+    new_row_stats=None,
+    new_col_stats=None,
+    bias=None,
+    compute_dtype=torch.float32,
+    output_dtype=torch.float32
+):
+    """
+    Dequant and add bias
+    out = A_int32 * scale_A * scale_B / (127 * 127) + bias
+    Args:
+        A               The output of int8 gemm, whose dtype is int32
+        quant_state     Not used for CPU
+        row_stats       Absolute max value of each row of input (A) of gemm
+        col_stats       Absolute max value of each row of weight (B) of gemm
+        out             Output buffer
+        new_row_stats   Not used for CPU/XPU
+        new_col_stats   Not used for CPU/XPU
+        bias            Bias of linear
+        compute_dtype   Data type for computation
+        output_dtype    Data type for output
+    Return:
+        The result
+    """
+    assert A.dtype == torch.int32
+    out_shape = A.shape
+    if len(out_shape) == 3:
+        out_shape = (out_shape[0] * out_shape[1], out_shape[2])
+
+    A_reshaped = A.reshape(out_shape).to(compute_dtype)
+    row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype)
+    col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype)
+    out = A_reshaped * row_stats * col_stats / (127 * 127)
+    if bias is not None:
+        out = out + bias.to(compute_dtype)
+    out = out.to(output_dtype)
+    return out
+
+
+class CPUBackend:
+    mm_dequant_compute_dtype = torch.bfloat16
+    mm_dequant_output_dtype = torch.bfloat16
+
+    @classmethod
+    def double_quant(
+        cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
+    ):
+        assert_on_cpu([A, col_stats, row_stats, out_col, out_row])
+        return double_quant_common(A, col_stats, row_stats, out_col, out_row, threshold)
+
+    @classmethod
+    def transform(cls, A, to_order=None, from_order='row', out=None, transpose=False, state=None, ld=None):
+        """
+        Transform tensor A to to_order. It is originally designed for CUDA.
+        For CPU, it returns the original tensor if transpose=False.
+        Otherwise, it returns the transpose of A.
+        """
+        assert_on_cpu([A, out])
+        if transpose:
+            if out is not None:
+                out.copy_(A.T)
+            else:
+                out = A.T
+        else:
+            if out is not None:
+                out.copy_(A)
+            else:
+                out = A
+        return out, state
+
+    @classmethod
+    def igemmlt(cls, A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32):
+        assert_on_cpu([A, B])
+        return igemmlt_common(A, B, SA, SB, out, Sout, dtype)
+
+    @classmethod
+    def mm_dequant(
+        cls,
+        A,
+        quant_state,
+        row_stats,
+        col_stats,
+        out=None,
+        new_row_stats=None,
+        new_col_stats=None,
+        bias=None
+    ):
+        assert_on_cpu([A, row_stats, col_stats, out, bias])
+        return mm_dequant_common(
+            A,
+            quant_state,
+            row_stats,
+            col_stats,
+            out,
+            new_row_stats,
+            new_col_stats,
+            bias,
+            cls.mm_dequant_compute_dtype,
+            cls.mm_dequant_output_dtype
+        )
+
+    @classmethod
+    def extract_outliers(cls, A, SA, idx):
+        """
+        Extract columns of A by idx
+        """
+        assert_on_cpu([A])
+        return A[:, idx].contiguous()
+
+    @classmethod
+    def quantize_4bit(
+        cls,
+        A: Tensor,
+        absmax: Tensor = None,
+        out: Tensor = None,
+        blocksize=64,
+        compress_statistics=False,
+        quant_type="fp4",
+    ) -> Tensor:
+        assert False, "quantize_4bit not yet implemented for CPU backend"
+
+    @classmethod
+    def dequantize_4bit(
+        cls,
+        A: Tensor,
+        quant_state = None,
+        absmax: Tensor = None,
+        out: Tensor = None,
+        blocksize: int = 64,
+        quant_type="fp4",
+    ) -> Tensor:
+        assert False, "dequantize_4bit not yet implemented for CPU backend"
diff --git a/bitsandbytes/backends/xpu.py b/bitsandbytes/backends/xpu.py
new file mode 100644
index 000000000..9ee8a09dc
--- /dev/null
+++ b/bitsandbytes/backends/xpu.py
@@ -0,0 +1,118 @@
+# For Intel GPU (xpu is the device name for Intel GPU in PyTorch)
+import torch
+from .cpu import (
+    double_quant_common,
+    igemmlt_common,
+    mm_dequant_common,
+)
+
+Tensor = torch.Tensor
+
+
+def assert_on_xpu(tensors):
+    on_xpu = True
+    for t in tensors:
+        if t is None: continue  # NULL pointers are fine
+        on_xpu &= (t.device.type == 'xpu')
+    if not on_xpu:
+        raise TypeError(
+            'All input tensors need to be on XPU, but found some tensors to not be on XPU:\n' \
+            f' {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}'
+        )
+    return on_xpu
+
+
+class XPUBackend:
+    mm_dequant_compute_dtype = torch.half
+    mm_dequant_output_dtype = torch.half
+
+    @classmethod
+    @torch.compile(dynamic=True, options={"fx_graph_cache": True})
+    def double_quant(
+        cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
+    ):
+        assert_on_xpu([A, col_stats, row_stats, out_col, out_row])
+        return double_quant_common(A, col_stats, row_stats, out_col, out_row, threshold)
+
+    @classmethod
+    def transform(cls, A, to_order=None, from_order='row', out=None, transpose=False, state=None, ld=None):
+        """
+        Transform tensor A to to_order. It is originally designed for CUDA.
+        For XPU, it returns the original tensor if transpose=False.
+        Otherwise, it returns the transpose of A.
+        """
+        assert_on_xpu([A, out])
+        if transpose:
+            if out is not None:
+                out.copy_(A.T)
+            else:
+                out = A.T
+        else:
+            if out is not None:
+                out.copy_(A)
+            else:
+                out = A
+        return out, state
+
+    @classmethod
+    def igemmlt(cls, A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32):
+        assert_on_xpu([A, B])
+        return igemmlt_common(A, B, SA, SB, out, Sout, dtype)
+
+    @classmethod
+    @torch.compile(dynamic=True, options={"fx_graph_cache": True})
+    def mm_dequant(
+        cls,
+        A,
+        quant_state,
+        row_stats,
+        col_stats,
+        out=None,
+        new_row_stats=None,
+        new_col_stats=None,
+        bias=None
+    ):
+        assert_on_xpu([A, row_stats, col_stats, out, bias])
+        return mm_dequant_common(
+            A,
+            quant_state,
+            row_stats,
+            col_stats,
+            out,
+            new_row_stats,
+            new_col_stats,
+            bias,
+            cls.mm_dequant_compute_dtype,
+            cls.mm_dequant_output_dtype
+        )
+
+    @classmethod
+    def extract_outliers(cls, A, SA, idx):
+        """
+        Extract columns of A by idx
+        """
+        assert_on_xpu([A])
+        return A[:, idx].contiguous()
+
+    @classmethod
+    def quantize_4bit(
+        cls,
+        A: Tensor,
+        absmax: Tensor = None,
+        out: Tensor = None,
+        blocksize=64,
+        compress_statistics=False,
+        quant_type="fp4",
+    ) -> Tensor:
+        assert False, "quantize_4bit not yet implemented for XPU backend"
+
+    @classmethod
+    def dequantize_4bit(
+        cls,
+        A: Tensor,
+        quant_state = None,
+        absmax: Tensor = None,
+        out: Tensor = None,
+        blocksize: int = 64,
+        quant_type="fp4",
+    ) -> Tensor:
+        assert False, "dequantize_4bit not yet implemented for XPU backend"
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 6bb02944d..baba76963 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1945,7 +1945,10 @@ class COOSparseTensor:
     def __init__(self, rows, cols, nnz, rowidx, colidx, values):
         assert rowidx.dtype == torch.int32
         assert colidx.dtype == torch.int32
-        assert values.dtype == torch.float16
+        if values.device == torch.device('cpu'):
+            assert values.dtype in [torch.bfloat16, torch.float]
+        else:
+            assert values.dtype == torch.float16
         assert values.numel() == nnz
         assert rowidx.numel() == nnz
         assert colidx.numel() == nnz
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index ec14e5940..bcba8b3d2 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -581,6 +581,32 @@ def cuda(self, device):
 
         return self
 
+    def cpu(self):
+        # we store the 8-bit row-major weight
+        B = self.data.contiguous().bfloat16().cpu()
+        CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
+        if CBt is not None:
+            del CBt
+        if SCBt is not None:
+            del SCBt
+        self.data = CB
+        setattr(self, "CB", CB)
+        setattr(self, "SCB", SCB)
+        return self
+
+    def xpu(self):
+        # we store the 8-bit row-major weight
+        B = self.data.contiguous().half().cpu()
+        CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
+        if CBt is not None:
+            del CBt
+        if SCBt is not None:
+            del SCBt
+        self.data = CB
+        setattr(self, "CB", CB)
+        setattr(self, "SCB", SCB)
+        return self
+
     @overload
     def to(
         self: T,
@@ -600,6 +626,18 @@ def to(self, *args, **kwargs):
 
         if device is not None and device.type == "cuda" and self.data.device.type == "cpu":
             return self.cuda(device)
+        elif (
+            device is not None
+            and device.type == "xpu"
+            and self.data.dtype != torch.int8
+        ):
+            return self.xpu()
+        elif (
+            device is not None
+            and device.type == "cpu"
+            and self.data.dtype != torch.int8
+        ):
+            return self.cpu()
         else:
             new_param = Int8Params(
                 super().to(device=device, dtype=dtype, non_blocking=non_blocking),
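
Reviewer note: a minimal sketch (not part of the patch) of how the CPU pieces above compose into the int8 linear path: row-wise int8 quantization of activations and weight, int8 x int8 -> int32 GEMM, then dequantization with the per-row/per-column scales. It assumes a recent PyTorch with torch._int_mm support on CPU; shapes and inputs are illustrative only.

    import torch
    from bitsandbytes.backends.cpu import CPUBackend

    A = torch.randn(32, 64, dtype=torch.bfloat16)    # activations
    W = torch.randn(128, 64, dtype=torch.bfloat16)   # weight, shape (out_features, in_features)

    # Row-wise symmetric int8 quantization; third return value is the row-wise absmax.
    CA, _, SCA, _, _ = CPUBackend.double_quant(A)
    CW, _, SCW, _, _ = CPUBackend.double_quant(W)

    # int8 GEMM (torch._int_mm on CPU), then rescale back to bf16 and add no bias.
    out32, _ = CPUBackend.igemmlt(CA, CW)
    out = CPUBackend.mm_dequant(out32, None, SCA, SCW)
    print(out.shape, out.dtype)   # torch.Size([32, 128]) torch.bfloat16

This mirrors what MatMul8bitLt.forward does once supports_igemmlt() returns True on CPU: double_quant -> igemmlt -> mm_dequant, with transform reduced to a no-op (or a transpose).
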
diff --git a/examples/int8_inference_huggingface_cpu.py b/examples/int8_inference_huggingface_cpu.py
new file mode 100644
index 000000000..b41605893
--- /dev/null
+++ b/examples/int8_inference_huggingface_cpu.py
@@ -0,0 +1,32 @@
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+import time
+
+MAX_NEW_TOKENS = 64
+model_id = "facebook/opt-1.3b"
+
+text = 'Hamburg is in which country?\n'
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+input_ids = tokenizer(text, return_tensors="pt").input_ids
+
+print('Loading model {}...'.format(model_id))
+quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map='auto',
+    quantization_config=quantization_config,
+    torch_dtype=torch.bfloat16
+)
+print('model dtype = {}'.format(model.dtype))
+
+with torch.no_grad():
+    t0 = time.time()
+    generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS)
+    latency = time.time() - t0
+    result = "| latency: " + str(round(latency * 1000, 3)) + " ms |"
+    print('+' + '-' * (len(result) - 2) + '+')
+    print(result)
+    print('+' + '-' * (len(result) - 2) + '+')
+
+output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+print(f"output: {output}")
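
A second sketch (also not part of the patch) illustrating the threshold/outlier handling documented in double_quant_common: values whose magnitude exceeds the threshold are returned in the COO tensor and excluded from the absmax statistics. It calls the helper directly and uses made-up values.

    import torch
    from bitsandbytes.backends.cpu import double_quant_common

    A = torch.tensor([[0.5, -1.0, 8.0],
                      [2.0,  0.0, -4.0]], dtype=torch.bfloat16)

    out_row, out_col, row_stats, col_stats, coo = double_quant_common(A, threshold=6.0)
    # 8.0 is the only value above the threshold: it is stored in `coo`
    # (int32 row/col indices plus the bf16 value) and zeroed out while the
    # per-row / per-column absolute maxima are computed, so row_stats is [1.0, 4.0].
    print(row_stats)
    print(coo.values)
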