diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index e6ec1d204..dab8de48a 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -17,11 +17,10 @@ matmul_cublas, mm_cublas, ) -from .backends import register_backend +from .backends import backends, register_backend from .backends.cpu import CPUBackend from .backends.npu import NPUBackend from .cextension import lib -from .nn import modules features = {"multi_backend"} supported_torch_devices = { @@ -76,6 +75,11 @@ if hasattr(torch, "npu") and torch.npu.is_available(): register_backend("npu", NPUBackend()) + +# import modules only after the backends have been decided +if backends: + from .nn import modules + # TODO: Other potential backends: # XLA - Google TPU / PJRT runtime # HPU - Habana / Intel Gaudi diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 59e26ad09..9765def05 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -221,7 +221,7 @@ def backward(ctx, grad_output): def supports_igemmlt(device: torch.device) -> bool: """check if this device supports the optimized int8 kernel""" - if device == torch.device("cpu"): + if device.type in ("cpu", "xpu"): return True if torch.version.hip: return False if BNB_HIP_VERSION < 601 else True @@ -463,7 +463,9 @@ def backward(ctx, grad_output): if len(grad_output.shape) == 3: grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) + Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = None, None, None, None, None + if req_gradB or (req_gradA and state.CBt is not None): + Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) if req_gradB: CxAt, SAt = F.transform(CAt, formatB, transpose=True) C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) @@ -575,8 +577,15 @@ def matmul_4bit( bias=None, ): assert quant_state is not None - if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False: - # CPU backend does not require A to be a vector + if A.device.type in ("cpu", "xpu") and A.requires_grad == False: + if getattr(quant_state, "ipex", False): + out = F.gemv_4bit(A, B.t(), out, state=quant_state) + if bias is not None: + out += bias + return out + else: + return MatMul4Bit.apply(A, B, out, bias, quant_state) + elif A.numel() == A.shape[-1] and A.requires_grad == False: if A.shape[-1] % quant_state.blocksize != 0: warn( f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow).
Matrix input size found: {A.shape}", diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index 01bafa332..85ad7b214 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -16,6 +16,7 @@ ipex_cpu = ipex if ipex._C._has_cpu() else None ipex_xpu = ipex if ipex._C._has_xpu() else None + ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu()) except BaseException: ipex_cpu = None ipex_xpu = None @@ -56,7 +57,7 @@ def _ipex_xpu_version_prereq(major, minor): def _maybe_torch_compile(func): # torch.compile requires g++ and pytorch >= 2.0 - if gxx_available and _torch_version_prereq(2, 0) and os.getenv('PT_HPU_LAZY_MODE',1)==0: + if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu and os.getenv('PT_HPU_LAZY_MODE', '1') == '0': options = {} # fx_graph_cache requires pytorch >= 2.2 if _torch_version_prereq(2, 2): @@ -182,7 +183,7 @@ def igemmlt_impl(A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32) A_reshaped = A.reshape(m, k) # torch._int_mm is available on CPU since torch 2.4 - if _torch_version_prereq(2, 4): + if _torch_version_prereq(2, 4) and A.device.type == "cpu": C = torch._int_mm(A_reshaped, B.T).to(dtype) else: C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype) @@ -234,8 +235,10 @@ def mm_dequant_impl( out_shape = (out_shape[0] * out_shape[1], out_shape[2]) if compute_dtype not in [torch.float32, torch.bfloat16]: - warnings.warn(f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use float instead") - compute_dtype = torch.float32 + warnings.warn( + f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use bfloat16 instead" + ) + compute_dtype = torch.bfloat16 A_reshaped = A.reshape(out_shape).to(compute_dtype) row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype) col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype) @@ -408,7 +411,6 @@ def dequantize_4bit_impl( torch.Tensor: Dequantized tensor.
""" - if A.shape[0] == 1: transpose = False A = A.squeeze(0) @@ -438,23 +440,27 @@ def dequantize_4bit_impl( if quant_state.nested: raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU") - if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(quant_state, "op_context"): - assert quant_state.op_context is not None - A = quant_state.op_context.to_public(quant_state.op_context.get_weight()) - A = A.reshape(-1) - absmax = quant_state.op_context.get_scales().reshape(-1) - - if out is None: - out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) + if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False): + A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2) + quant_state.ipex = False - n = out.numel() # Map nf4 to [-1, 1] +<<<<<<< HEAD out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device) out_uint8[::2] = A.bitwise_and(0xF) out_uint8[1::2] = A.bitwise_right_shift(4) out_dq = torch.empty(out_uint8.shape, dtype=quant_state.code.dtype, device= quant_state.code.device) for i in range(len(quant_state.code)): out_dq[out_uint8 == i] = quant_state.code[i] +======= + out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) + n = out_dq.numel() + out_dq[::2] = A & 0xF + out_dq[1::2] = A >> 4 + # quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue + quant_state.code = quant_state.code.to(quant_state.dtype) + out_dq = quant_state.code[out_dq] +>>>>>>> b2ac423 (Enable XPU and optimize cpu/xpu op (#1418)) # Apply scales if out_dq.numel() != n: @@ -464,12 +470,17 @@ def dequantize_4bit_impl( blocks += 1 if n % blocksize > 0 else 0 rem = n % blocksize has_rem = rem > 0 - out_reshaped = out.reshape(-1) - out_reshaped[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape( - -1 - ) + if has_rem: + if out is None: + out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) + out_reshaped = out.reshape(-1) + out_reshaped[: n - rem] = ( + out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1) + ).reshape(-1) out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1] + else: + out = (out_dq.view(-1, blocksize) * absmax.view(-1, 1)).reshape(quant_state.shape).to(quant_state.dtype) # take transpose here because weight is transposed (again) for computation if transpose: @@ -510,9 +521,21 @@ def gemm_4bit_impl( torch.Tensor: GEMM output tensor. 
""" - if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"): - assert state.op_context is not None - output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle()) + if getattr(state, "ipex", False): + output = torch.ops.torch_ipex.woq_linear( + A, + B, + "nf4", + state.shape, + state.new_scales, + state.new_zeros, + None, + None, + state.blocksize, + ipex_cpu.quantization.WoqLowpMode.BF16, + 1, + state.compensation, + ) else: dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t() output = torch.matmul(A, dqB.to(A.dtype)) diff --git a/bitsandbytes/backends/xpu.py b/bitsandbytes/backends/xpu.py index 3976c4d5a..bc13963e6 100644 --- a/bitsandbytes/backends/xpu.py +++ b/bitsandbytes/backends/xpu.py @@ -5,9 +5,36 @@ from bitsandbytes.utils import QuantState from .base import Backend +from .cpu_xpu_common import ( + dequantize_4bit_impl, + double_quant_impl, + gemm_4bit_impl, + igemmlt_impl, + mm_dequant_impl, + quantize_4bit_impl, +) + +Tensor = torch.Tensor + + +def assert_on_xpu(tensors): + on_xpu = True + for t in tensors: + if t is None: + continue # NULL pointers are fine + on_xpu &= t.device.type == "xpu" + if not on_xpu: + raise TypeError( + "All input tensors need to be on XPU, but found some tensors to not be on XPU:\n" + f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}" + ) + return on_xpu class XPUBackend(Backend): + mm_dequant_compute_dtype = torch.bfloat16 + mm_dequant_output_dtype = torch.bfloat16 + def double_quant( self, A: torch.Tensor, @@ -17,7 +44,9 @@ def double_quant( out_row: Optional[torch.Tensor] = None, threshold=0.0, ): - raise NotImplementedError + assert_on_xpu([A, col_stats, row_stats, out_col, out_row]) + output = double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold) + return output def transform( self, @@ -29,7 +58,23 @@ def transform( state: Optional[Tuple[torch.Size, str]] = None, ld=None, ): - raise NotImplementedError + """ + Transform tensor A to to_order. It is originally designed for CUDA. + For XPU, it returns the original tensor if transpose=False. 
+ Otherwise, it returns the transpose of A + """ + assert_on_xpu([A, out]) + if transpose: + if out is not None: + out.copy_(A.T) + else: + out = A.T + else: + if out is not None: + out.copy_(A) + else: + out = A + return out, state def igemmlt( self, @@ -41,7 +86,9 @@ def igemmlt( Sout: Optional[Tuple[torch.Size, str]] = None, dtype=torch.int32, ) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]: - raise NotImplementedError + assert_on_xpu([A, B]) + output = igemmlt_impl(A, B, SA, SB, out, Sout, dtype) + return output def mm_dequant( self, @@ -54,7 +101,20 @@ def mm_dequant( new_col_stats: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - raise NotImplementedError + assert_on_xpu([A, row_stats, col_stats, out, bias]) + output = mm_dequant_impl( + A, + quant_state, + row_stats, + col_stats, + out, + new_row_stats, + new_col_stats, + bias, + self.mm_dequant_compute_dtype, + self.mm_dequant_output_dtype, + ) + return output def extract_outliers( self, @@ -62,7 +122,9 @@ def extract_outliers( SA: Tuple[torch.Size, str], idx: torch.Tensor, ) -> torch.Tensor: - raise NotImplementedError + assert_on_xpu([A]) + output = A[:, idx].contiguous() + return output def quantize_4bit( self, @@ -74,7 +136,12 @@ def quantize_4bit( quant_type: Literal["fp4", "nf4"] = "fp4", quant_storage=torch.uint8, ) -> Tuple[torch.Tensor, QuantState]: - raise NotImplementedError + if blocksize is None: + blocksize = 64 + assert_on_xpu([A, absmax, out]) + assert quant_storage == torch.uint8, "XPU backend only supports uint8 quant_storage" + output = quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type) + return output def dequantize_4bit( self, @@ -85,7 +152,15 @@ def dequantize_4bit( blocksize: int = 64, quant_type: Literal["fp4", "nf4"] = "fp4", ) -> torch.Tensor: - raise NotImplementedError + if blocksize is None: + blocksize = 64 + assert_on_xpu([A, absmax, out]) + if quant_type == "nf4": + output = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t() + else: + output = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type) + + return output def gemv_4bit( self, @@ -96,7 +171,11 @@ def gemv_4bit( transposed_B=False, state: QuantState = None, ) -> torch.Tensor: - raise NotImplementedError + assert_on_xpu([A, B, out]) + if state is None: + raise ValueError("state cannot be None. 
gemv_4bit() requires the state from quantize_4bit()") + output = gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state) + return output def dequantize_blockwise( self, diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 6cf64df28..3c730cb16 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1006,11 +1006,6 @@ def dequantize_fp4( out: Optional[torch.Tensor] = None, blocksize: Optional[int] = None, ) -> Tensor: - if blocksize is None: - # Some AMD GPUs have warpsize 64 - # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP - blocksize = 64 if not HIP_ENVIRONMENT else 128 - return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") @@ -1021,11 +1016,6 @@ def dequantize_nf4( out: Optional[torch.Tensor] = None, blocksize: Optional[int] = None, ) -> Tensor: - if blocksize is None: - # Some AMD GPUs have warpsize 64 - # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP - blocksize = 64 if not HIP_ENVIRONMENT else 128 - return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") @@ -1035,7 +1025,7 @@ def dequantize_4bit( absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: Optional[int] = None, - quant_type="fp4", + quant_type=None, ) -> Tensor: """ Dequantizes FP4 blockwise quantized values. @@ -1064,6 +1054,14 @@ def dequantize_4bit( Dequantized tensor. """ ensure_backend_is_available(A.device.type) + if quant_state is not None: + absmax = absmax if absmax is not None else quant_state.absmax + quant_type = quant_type or quant_state.quant_type + blocksize = blocksize or quant_state.blocksize + if blocksize is None: + # Some AMD GPUs have warpsize 64 + # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP + blocksize = 64 if not HIP_ENVIRONMENT else 128 return backends[A.device.type].dequantize_4bit( A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type ) @@ -1800,7 +1798,7 @@ class COOSparseTensor: def __init__(self, rows, cols, nnz, rowidx, colidx, values): assert rowidx.dtype == torch.int32 assert colidx.dtype == torch.int32 - if values.device == torch.device("cpu"): + if values.device.type in ("cpu", "xpu"): assert values.dtype in [torch.bfloat16, torch.half, torch.float] else: assert values.dtype == torch.float16 diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 96f4359bf..35bee393e 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from ..backends import backends from .modules import ( Embedding, Int8Params, @@ -14,9 +15,12 @@ StableEmbedding, SwitchBackLinearBnb, ) -from .triton_based_modules import ( - StandardLinear, - SwitchBackLinear, - SwitchBackLinearGlobal, - SwitchBackLinearVectorwise, -) + +# The CPU and XPU backends do not need triton, and XPU does not support triton for now.
+if "xpu" not in backends.keys() and len(backends.keys()) > 1: + from .triton_based_modules import ( + StandardLinear, + SwitchBackLinear, + SwitchBackLinearGlobal, + SwitchBackLinearVectorwise, + ) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py old mode 100644 new mode 100755 index 9a430be5c..8e5fa63dc --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -314,6 +314,9 @@ def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: b def cpu(self, non_blocking: bool = False): return self.to(device="cpu", non_blocking=non_blocking) + def xpu(self, non_blocking: bool = False): + return self.to(device="xpu", non_blocking=non_blocking) + @overload def to( self: T, @@ -331,7 +334,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device is not None and device.type in ["cuda", "cpu", "hpu"] and not self.bnb_quantized: + if device is not None and device.type in ["cuda", "cpu", "xpu", "hpu"] and not self.bnb_quantized: return self._quantize(device) else: if self.quant_state is not None: @@ -417,6 +420,7 @@ def __init__( # self.persistent_buffers = [] # TODO consider as way to save quant state self.compute_dtype = compute_dtype self.compute_type_is_set = False + self.ipex_linear_is_set = False self.quant_state = None self.quant_storage = quant_storage @@ -445,35 +449,39 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): save weight and bias, then fill state_dict with components of quant_state """ - if ( - getattr(self.weight, "quant_state", None) is not None - and getattr(self.weight.quant_state, "op_context", None) is not None - ): - context = self.weight.quant_state.op_context - self.weight.data = context.to_public(context.get_weight()).reshape([1, -1]) + if getattr(self.weight, "quant_state", None) is not None and getattr(self.weight.quant_state, "ipex", False): + if self.weight.device.type == "cpu": + original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight( + self.weight, "nf4", self.weight.quant_state.shape, 2 + ) + self.weight.data = original_weight.data + elif self.weight.device.type == "xpu": + self.weight.data = self.weight.data.reshape(1, -1) + + self.weight.quant_state.ipex = False super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias if getattr(self.weight, "quant_state", None) is not None: - if ( - self.weight.quant_state.absmax.shape.numel() == 0 - and getattr(self.weight.quant_state, "op_context", None) is not None - ): - self.weight.quant_state.absmax = context.get_scales().reshape(-1) - delattr(self.weight.quant_state, "op_context") for k, v in self.weight.quant_state.as_dict(packed=True).items(): destination[prefix + "weight." 
+ k] = v if keep_vars else v.detach() - def forward(self, x: torch.Tensor): - # Check if ipex fusion can be used + def set_ipex_linear(self, x: torch.Tensor): if ( - x.device.type == "cpu" - and not hasattr(self.weight.quant_state, "op_context") + (x.device.type in ("cpu", "xpu")) + and not getattr(self.weight.quant_state, "ipex", False) and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0 and self.weight.quant_state.quant_type == "nf4" + and not self.training and x.requires_grad == False ): - enable_ipex_fusion(self.weight, self.weight.quant_state) + enable_ipex_fusion(self) + + def forward(self, x: torch.Tensor): + # Check if ipex fusion can be used + if not self.ipex_linear_is_set: + self.set_ipex_linear(x) + self.ipex_linear_is_set = True # weights are cast automatically as Int8Params, but the bias has to be cast manually if self.bias is not None and self.bias.dtype != x.dtype: @@ -633,7 +641,20 @@ def __deepcopy__(self, memo): def cpu(self): # we store the 8-bit rows-major weight - B = self.data.contiguous().bfloat16().cpu() + B = self.data.contiguous().to(torch.bfloat16).cpu() + CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) + if CBt is not None: + del CBt + if SCBt is not None: + del SCBt + self.data = CB + self.CB = CB + self.SCB = SCB + return self + + def xpu(self): + # we store the 8-bit rows-major weight + B = self.data.contiguous().to(torch.float16).xpu() CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) if CBt is not None: del CBt @@ -669,6 +690,13 @@ def to(self, *args, **kwargs): return self else: return self.cpu() + elif device.type == "xpu": + if self.data.dtype == torch.int8: + self.data = self.data.contiguous().xpu() + self.CB = self.data + return self + else: + return self.xpu() else: new_param = Int8Params( super().to(device=device, dtype=dtype, non_blocking=non_blocking), diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 9e52c915d..adb36279c 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -200,28 +200,39 @@ def unpack_tensor_to_dict(tensor_data): return unpacked_dict -def enable_ipex_fusion(weight, quant_state): - from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq - - if _ipex_cpu_version_prereq(2, 3): - import intel_extension_for_pytorch as ipex - - lowp_mode = ipex.quantization.WoqLowpMode.BF16 - quant_state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack( - weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]), - ipex.quantization.WoqWeightDtype.NF4, +def enable_ipex_fusion(linear): + from bitsandbytes.backends.cpu_xpu_common import ( + _ipex_cpu_version_prereq, + _ipex_xpu_version_prereq, + ipex_cpu_only, + ipex_xpu, + ) + + if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5): + quant_state = linear.weight.quant_state + new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight( + linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]), + "nf4", quant_state.shape, # weight shape quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales None, # zero_points None, # bias - None, # g_idx None, # batch_size quant_state.blocksize, - int(lowp_mode), - -1, # act_quant_mode. 
-1 means don't quant activation + 2, ) - quant_state.absmax = torch.Tensor() - weight.data = torch.empty([1, 0], dtype=torch.uint8) + elif ipex_xpu and _ipex_xpu_version_prereq(2, 5): + quant_state = linear.weight.quant_state + new_weight = linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]) + + new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize) + new_zeros = None + compensation = None + linear.weight.data = new_weight.data + linear.weight.quant_state.ipex = True + linear.weight.quant_state.new_scales = new_scales + linear.weight.quant_state.new_zeros = new_zeros + linear.weight.quant_state.compensation = compensation class QuantState: diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index d1acb2cd6..615dfd95e 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -208,8 +208,8 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/YOUR_USERNAME/local/cuda-11.7 |-------------|------------------------|---------------------------|-------------------------|------------| | **AMD ROCm** | 6.1+ | 3.10+ | minimum CDNA - `gfx90a`, RDNA - `gfx1100` | Alpha | | **Apple Silicon (MPS)** | WIP | 3.10+ | M1/M2 chips | Planned | -| **Intel CPU** | v2.4.0+ (`ipex`) | 3.10+ | Intel CPU | Alpha | -| **Intel GPU** | v2.4.0+ (`ipex`) | 3.10+ | Intel GPU | Experimental | +| **Intel CPU** | v2.5.0+ (`ipex`) | 3.10+ | Intel CPU | Alpha | +| **Intel GPU** | v2.5.0+ (`ipex`) | 3.10+ | Intel GPU | Experimental | For each supported backend, follow the respective instructions below: @@ -336,8 +336,6 @@ The below commands are for Linux. For installing on Windows, please adapt the be git clone --depth 1 -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/ pip install intel_extension_for_pytorch pip install -r requirements-dev.txt -cmake -DCOMPUTE_BACKEND=cpu -S . -make pip install -e . # `-e` for "editable" install, when developing BNB (otherwise leave that out) ``` diff --git a/docs/source/non_cuda_backends.mdx b/docs/source/non_cuda_backends.mdx index 728606b7b..4c429fb2d 100644 --- a/docs/source/non_cuda_backends.mdx +++ b/docs/source/non_cuda_backends.mdx @@ -33,12 +33,12 @@ The following performance data is collected from Intel 4th Gen Xeon (SPR) platfo | Data Type | BF16 | INT8 | NF4 | FP4 | |---|---|---|---|---| -| Speed-Up (vs BF16) | 1.0x | 0.6x | 2.3x | 0.03x | +| Speed-Up (vs BF16) | 1.0x | 0.44x | 1.8x | 0.1x | | Memory (GB) | 13.1 | 7.6 | 5.0 | 4.6 | #### Fine-Tuning (CPU) -| Data Type | AMP BF16 | INT8 | NF4 | FP4 | +| Data Type | BF16 | INT8 | NF4 | FP4 | |---|---|---|---|---| -| Speed-Up (vs AMP BF16) | 1.0x | 0.38x | 0.07x | 0.07x | +| Speed-Up (vs BF16) | 1.0x | 0.38x | 0.1x | 0.1x | | Memory (GB) | 40 | 9 | 6.6 | 6.6 |
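As a quick smoke test of the 4-bit XPU inference route this patch enables, the sketch below moves a `Linear4bit` layer to the `xpu` device and runs a forward pass. This is a minimal sketch rather than part of the patch: it assumes a machine with `intel_extension_for_pytorch` installed so that the `xpu` device exists, and the layer sizes are illustrative only.

```python
# Minimal sketch (assumes intel_extension_for_pytorch provides the "xpu" device;
# sizes are illustrative only and not part of this diff).
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear4bit(
    4096,
    4096,
    bias=False,
    compute_dtype=torch.bfloat16,
    quant_type="nf4",
)
# Params4bit.to() now quantizes on "xpu" in addition to "cuda"/"cpu"/"hpu".
layer = layer.to("xpu").eval()

x = torch.randn(1, 4096, dtype=torch.bfloat16, device="xpu")
with torch.no_grad():
    # On the first call, forward() checks whether the IPEX WOQ fused kernel can be
    # enabled (set_ipex_linear); otherwise it falls back to dequantize + matmul.
    y = layer(x)
print(y.shape)
```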
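The `functional.py` change also lets `dequantize_4bit()` pick up `absmax`, `blocksize`, and `quant_type` from the `QuantState` when they are not passed explicitly. A hedged sketch of that behavior, assuming the CPU backend of this branch; the tensor shape and dtype are illustrative:

```python
# Sketch of the new defaulting behavior in F.dequantize_4bit (CPU backend of this
# branch; shapes/dtypes are illustrative and not part of the diff).
import torch
import bitsandbytes.functional as F

w = torch.randn(4096, 4096, dtype=torch.bfloat16)  # CPU tensor
packed, quant_state = F.quantize_4bit(w, quant_type="nf4", blocksize=64)

# No explicit absmax/blocksize/quant_type: they are now taken from quant_state.
w_dq = F.dequantize_4bit(packed, quant_state)
print((w - w_dq).abs().mean())
```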