From 09cc153dea939f23747bea622560c84b5a95183f Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 8 May 2024 02:10:49 -0700 Subject: [PATCH 1/6] Support NF4 on CPU backend --- bitsandbytes/autograd/_functions.py | 3 +- bitsandbytes/backends/cpu.py | 15 +- bitsandbytes/backends/cpu_xpu_common.py | 266 +++++++++++++++++++++++- bitsandbytes/nn/modules.py | 7 +- 4 files changed, 284 insertions(+), 7 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 7d570f28b..6dea211ff 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -572,7 +572,8 @@ def matmul_4bit( bias=None, ): assert quant_state is not None - if A.numel() == A.shape[-1] and A.requires_grad == False: + if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False: + # CPU backend does not require A to be a vector if A.shape[-1] % quant_state.blocksize != 0: warn( f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}", diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py index d6a9192e4..a5e123e62 100644 --- a/bitsandbytes/backends/cpu.py +++ b/bitsandbytes/backends/cpu.py @@ -9,6 +9,9 @@ double_quant_impl, igemmlt_impl, mm_dequant_impl, + quantize_4bit_impl, + dequantize_4bit_impl, + gemm_4bit_impl, ) Tensor = torch.Tensor @@ -132,7 +135,8 @@ def quantize_4bit( quant_type: Literal["fp4", "nf4"] = "fp4", quant_storage=torch.uint8, ) -> Tuple[torch.Tensor, QuantState]: - raise NotImplementedError("Not yet implemented for CPU backend") + assert_on_cpu([A, absmax, out]) + return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type) def dequantize_4bit( self, @@ -143,7 +147,8 @@ def dequantize_4bit( blocksize: int = 64, quant_type: Literal["fp4", "nf4"] = "fp4", ) -> torch.Tensor: - raise NotImplementedError("Not yet implemented for CPU backend") + assert_on_cpu([A, absmax, out]) + return dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type) def gemv_4bit( self, @@ -154,7 +159,11 @@ def gemv_4bit( transposed_B=False, state: QuantState = None, ) -> torch.Tensor: - raise NotImplementedError("Not yet implemented for CPU backend") + assert_on_cpu([A, B, out]) + if state is None: + raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()") + + return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state) def dequantize_blockwise( self, diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index f4e5ed3ec..078b81680 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -1,6 +1,12 @@ import warnings - import torch +from typing import Optional +from bitsandbytes.functional import ( + get_4bit_type, + quantize_blockwise, + dequantize_blockwise, + QuantState, +) try: # to support Intel CPU/GPU (XPU) backend @@ -228,3 +234,261 @@ def mm_dequant_impl( out = out + bias.to(compute_dtype) out = out.to(output_dtype) return out + + +NF4_QUANT_TABLE = [ + -1.0 - 1e-2, # 0b0000 + -0.8480964004993439, # 0b0001 + -0.6106329262256622, # 0b0010 + -0.4599952697753906, # 0b0011 + -0.33967943489551544, # 0b0100 + -0.23460740596055984, # 0b0101 + -0.13791173323988914, # 0b0110 + -0.045525018125772476, # 0b0111 + 0.03979014977812767, # 0b1000 + 0.1202552504837513, # 0b1001 + 0.2035212516784668, # 0b1010 + 0.2920137718319893, # 0b1011 + 0.3893125355243683, # 0b1100 + 0.5016634166240692, # 0b1101 + 0.6427869200706482, # 0b1110 + 0.8614784181118011, # 0b1111 +] + + +# It's faster not to use torch.compile +def quantize_4bit_impl( + A: Tensor, + absmax: Tensor = None, + out: Tensor = None, + blocksize=64, + compress_statistics=False, + quant_type="nf4", +) -> Tensor: + """ + Quantize tensor A in blocks of 4-bit values. + + Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + The output tensor (8-bit). + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now + + Returns + ------- + torch.Tensor: + The 8-bit tensor with packed 4-bit values. + tuple(torch.Tensor, torch.Size, torch.dtype, int): + The quantization state to undo the quantization. + """ + if quant_type != "nf4": + raise NotImplementedError( + f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU." + ) + n = A.numel() + input_shape = A.shape + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + + if absmax is None: + absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) + + if out is None: + out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device) + + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + rem = n % blocksize + has_rem = rem > 0 + + # Scale tensor to [-1, 1] + A_reshaped = A.reshape(n) + A_com = A_reshaped[:n - rem] + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) + absmax[:blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[:blocks - has_rem].view(-1, 1)), -1, 1) + scaled_A = scaled_A.reshape(-1) + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem:]).max() + scaled_A_rem = torch.clamp(A_reshaped[n - rem:] * (1 / absmax[-1]), -1, 1) + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) + # map [-1, 1] to nf4 + out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8) + for i in range(len(NF4_QUANT_TABLE)): + out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i + if out_uint8.size(-1) % 2: + out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0) + out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2]) + + code = get_4bit_type(quant_type, device=A.device) + + if compress_statistics: + raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU") + else: + state = QuantState( + absmax=absmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + ) + + if ipex_cpu and _ipex_cpu_version_prereq(2, 2) and input_shape[0] % blocksize == 0: + state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack( + out.reshape([input_shape[0], input_shape[1] // 2]), + ipex_cpu.quantization.WoqWeightDtype.NF4, + input_shape, # weight shape + absmax.view(input_shape[0], input_shape[1] // blocksize), # scales + None, # zero_points + None, # bias + None, # g_idx + None, # batch_size + blocksize, + int(ipex_cpu.quantization.WoqLowpMode.BF16), + -1, # act_quant_mode + ) + + return out, state + + +@_maybe_torch_compile +def dequantize_4bit_impl( + A: Tensor, + quant_state = None, + absmax: Tensor = None, + out: Tensor = None, + blocksize: int = 64, + quant_type="nf4", +) -> Tensor: + """ + Dequantizes FP4 blockwise quantized values. + + Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. + + Parameters + ---------- + A : torch.Tensor + The input 8-bit tensor (packed 4-bit values). + quant_state : QuantState + object with quantisation stats, incl. absmax values, original tensor shape and original dtype. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + Dequantized output tensor. + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now + + + Returns + ------- + torch.Tensor: + Dequantized tensor. + """ + + if quant_state is None: + assert absmax is not None and out is not None + + quant_state = QuantState( + absmax=absmax, + shape=out.shape, + dtype=out.dtype, + blocksize=blocksize, + quant_type=quant_type, + ) + + else: + absmax = quant_state.absmax + + if quant_state.quant_type != "nf4": + raise NotImplementedError( + f"4-bit quantization data type {quant_state.quant_type} is not implemented for CPU/XPU." + ) + + if quant_state.nested: + raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU") + + if out is None: + out = torch.empty( + quant_state.shape, dtype=quant_state.dtype, device=A.device + ) + + n = out.numel() + # Map nf4 to [-1, 1] + out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device) + out_uint8[::2] = A.bitwise_and(0xF) + out_uint8[1::2] = A.bitwise_right_shift(4) + out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype) + for i in range(len(quant_state.code)): + out_dq[out_uint8 == i] = quant_state.code[i] + + # Apply scales + if out_dq.numel() != n: + assert out_dq.numel() == n + 1 + out_dq = torch.narrow(out_dq, 0, 0, n) + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize + has_rem = rem > 0 + out_reshaped = out.reshape(-1) + out_reshaped[:n - rem] = (out_dq[:n - rem].view(-1, blocksize) * absmax[:blocks - has_rem].view(-1, 1)).reshape(-1) + if has_rem: + out_reshaped[n - rem:] = out_dq[n - rem:] * absmax[-1] + + # take transpose here because weight is transposed (again) for computation + return out.t() + + +# Do not need torch.compile here as we are calling torch/ipex kernel +def gemm_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + transposed_A=False, + transposed_B=False, + state: QuantState = None, +) -> torch.Tensor: + """ + Matrix-matrix multiplication with 4-bit quantization. + + Parameters + ---------- + A : torch.Tensor + The first input tensor. Usually the activation tensor. + B : torch.Tensor + The second input tensor. Usually the weight tensor. + out : torch.Tensor + The output tensor. + transposed_A : bool + Whether A is transposed + transposed_B : bool + Whether B is transposed + state : QuantState + Contains quantization info, such as blocksize and dtype + + Returns + ------- + torch.Tensor: + GEMM output tensor. + """ + if ipex_cpu and _ipex_cpu_version_prereq(2, 2) and hasattr(state, "op_context"): + assert state.op_context is not None + output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle()) + else: + dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize) + output = torch.matmul(A, dqB) + if out is not None: + out.copy_(output) + else: + out = output + return out diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 7e9ab8d05..d52cb4847 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -285,7 +285,7 @@ def from_prequantized( return self def _quantize(self, device): - w = self.data.contiguous().cuda(device) + w = self.data.contiguous().to(device) w_4bit, quant_state = bnb.functional.quantize_4bit( w, blocksize=self.blocksize, @@ -303,6 +303,9 @@ def _quantize(self, device): def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False): return self.to(device="cuda" if device is None else device, non_blocking=non_blocking) + def cpu(self, non_blocking: bool = False): + return self.to(device="cpu", non_blocking=non_blocking) + @overload def to( self: T, @@ -320,7 +323,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ... def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device is not None and device.type == "cuda" and not self.bnb_quantized: + if device is not None and device.type in ["cuda", "cpu"] and not self.bnb_quantized: return self._quantize(device) else: if self.quant_state is not None: From 177bd398b3235f586e9e2110b6ffe8288eef4f00 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Fri, 10 May 2024 00:22:04 -0700 Subject: [PATCH 2/6] Minor improvements --- bitsandbytes/backends/cpu.py | 1 + bitsandbytes/backends/cpu_xpu_common.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py index a5e123e62..80b6c241e 100644 --- a/bitsandbytes/backends/cpu.py +++ b/bitsandbytes/backends/cpu.py @@ -136,6 +136,7 @@ def quantize_4bit( quant_storage=torch.uint8, ) -> Tuple[torch.Tensor, QuantState]: assert_on_cpu([A, absmax, out]) + assert quant_storage == torch.uint8, "CPU backend only supports uint8 quant_storage" return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type) def dequantize_4bit( diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index 078b81680..ab881c6dd 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -343,6 +343,8 @@ def quantize_4bit_impl( ) if ipex_cpu and _ipex_cpu_version_prereq(2, 2) and input_shape[0] % blocksize == 0: + # lowp_mode: lowest precision for computation + lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16 state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack( out.reshape([input_shape[0], input_shape[1] // 2]), ipex_cpu.quantization.WoqWeightDtype.NF4, @@ -353,8 +355,8 @@ def quantize_4bit_impl( None, # g_idx None, # batch_size blocksize, - int(ipex_cpu.quantization.WoqLowpMode.BF16), - -1, # act_quant_mode + int(lowp_mode), + -1, # act_quant_mode. -1 means don't quant activation ) return out, state From 881b5fcd0bc77f747850f397a0bf02c288332c17 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Fri, 10 May 2024 22:34:32 -0700 Subject: [PATCH 3/6] Add fp4 support; add UT; fix lint issues --- bitsandbytes/backends/cpu.py | 4 +- bitsandbytes/backends/cpu_xpu_common.py | 109 ++++++++++++++---------- tests/test_functional.py | 50 ++++++++++- 3 files changed, 114 insertions(+), 49 deletions(-) diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py index 80b6c241e..2c3688251 100644 --- a/bitsandbytes/backends/cpu.py +++ b/bitsandbytes/backends/cpu.py @@ -6,12 +6,12 @@ from .base import Backend from .cpu_xpu_common import ( + dequantize_4bit_impl, double_quant_impl, + gemm_4bit_impl, igemmlt_impl, mm_dequant_impl, quantize_4bit_impl, - dequantize_4bit_impl, - gemm_4bit_impl, ) Tensor = torch.Tensor diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index ab881c6dd..8d87f7e2f 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -1,11 +1,11 @@ +from typing import Optional import warnings + import torch -from typing import Optional + from bitsandbytes.functional import ( - get_4bit_type, - quantize_blockwise, - dequantize_blockwise, QuantState, + get_4bit_type, ) try: @@ -237,25 +237,37 @@ def mm_dequant_impl( NF4_QUANT_TABLE = [ - -1.0 - 1e-2, # 0b0000 - -0.8480964004993439, # 0b0001 - -0.6106329262256622, # 0b0010 - -0.4599952697753906, # 0b0011 + -1.0 - 1e-2, # 0b0000 + -0.8480964004993439, # 0b0001 + -0.6106329262256622, # 0b0010 + -0.4599952697753906, # 0b0011 -0.33967943489551544, # 0b0100 -0.23460740596055984, # 0b0101 -0.13791173323988914, # 0b0110 - -0.045525018125772476, # 0b0111 - 0.03979014977812767, # 0b1000 - 0.1202552504837513, # 0b1001 - 0.2035212516784668, # 0b1010 - 0.2920137718319893, # 0b1011 - 0.3893125355243683, # 0b1100 - 0.5016634166240692, # 0b1101 - 0.6427869200706482, # 0b1110 - 0.8614784181118011, # 0b1111 + -0.045525018125772476, # 0b0111 + 0.03979014977812767, # 0b1000 + 0.1202552504837513, # 0b1001 + 0.2035212516784668, # 0b1010 + 0.2920137718319893, # 0b1011 + 0.3893125355243683, # 0b1100 + 0.5016634166240692, # 0b1101 + 0.6427869200706482, # 0b1110 + 0.8614784181118011, # 0b1111 ] +FP4_QUANT_TABLE = { + 0 - 1e-2: 0, # 0b0000 + 0.00260417: 1, # 0b0001 + 0.0859375: 6, # 0b0110 + 0.20833333: 7, # 0b0111 + 0.29166667: 4, # 0b0100 + 0.4166667: 5, # 0b0101 + 0.583333: 2, # 0b0010 + 0.8333333: 3, # 0b0011 +} + + # It's faster not to use torch.compile def quantize_4bit_impl( A: Tensor, @@ -290,10 +302,11 @@ def quantize_4bit_impl( tuple(torch.Tensor, torch.Size, torch.dtype, int): The quantization state to undo the quantization. """ - if quant_type != "nf4": - raise NotImplementedError( - f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU." - ) + if quant_type not in ["nf4", "fp4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.") + if quant_type == "fp4": + warnings.warn("fp4 quantization is currently slow on CPU/XPU. Please Use nf4 instead for better performance.") + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] n = A.numel() input_shape = A.shape blocks = n // blocksize @@ -305,25 +318,31 @@ def quantize_4bit_impl( if out is None: out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device) - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] rem = n % blocksize has_rem = rem > 0 # Scale tensor to [-1, 1] A_reshaped = A.reshape(n) - A_com = A_reshaped[:n - rem] + A_com = A_reshaped[: n - rem] A_com_reshaped = A_com.reshape(n // blocksize, blocksize) - absmax[:blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] - scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[:blocks - has_rem].view(-1, 1)), -1, 1) + absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) scaled_A = scaled_A.reshape(-1) if has_rem: - absmax[-1] = torch.abs(A_reshaped[n - rem:]).max() - scaled_A_rem = torch.clamp(A_reshaped[n - rem:] * (1 / absmax[-1]), -1, 1) + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() + scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) - # map [-1, 1] to nf4 + # map [-1, 1] to nf4/fp4 out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8) - for i in range(len(NF4_QUANT_TABLE)): - out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i + if quant_type == "nf4": + for i in range(len(NF4_QUANT_TABLE)): + out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i + elif quant_type == "fp4": + sign = scaled_A < 0 + abs_scaled_A = torch.abs(scaled_A) + for key, val in FP4_QUANT_TABLE.items(): + out_uint8[abs_scaled_A > key] = val + out_uint8 += sign.to(torch.uint8) * 8 if out_uint8.size(-1) % 2: out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0) out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2]) @@ -342,21 +361,21 @@ def quantize_4bit_impl( quant_type=quant_type, ) - if ipex_cpu and _ipex_cpu_version_prereq(2, 2) and input_shape[0] % blocksize == 0: + if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4": # lowp_mode: lowest precision for computation lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16 state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack( out.reshape([input_shape[0], input_shape[1] // 2]), ipex_cpu.quantization.WoqWeightDtype.NF4, - input_shape, # weight shape - absmax.view(input_shape[0], input_shape[1] // blocksize), # scales - None, # zero_points - None, # bias - None, # g_idx - None, # batch_size + input_shape, # weight shape + absmax.view(input_shape[0], input_shape[1] // blocksize), # scales + None, # zero_points + None, # bias + None, # g_idx + None, # batch_size blocksize, int(lowp_mode), - -1, # act_quant_mode. -1 means don't quant activation + -1, # act_quant_mode. -1 means don't quant activation ) return out, state @@ -365,7 +384,7 @@ def quantize_4bit_impl( @_maybe_torch_compile def dequantize_4bit_impl( A: Tensor, - quant_state = None, + quant_state=None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, @@ -412,7 +431,7 @@ def dequantize_4bit_impl( else: absmax = quant_state.absmax - if quant_state.quant_type != "nf4": + if quant_type not in ["nf4", "fp4"]: raise NotImplementedError( f"4-bit quantization data type {quant_state.quant_type} is not implemented for CPU/XPU." ) @@ -421,9 +440,7 @@ def dequantize_4bit_impl( raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU") if out is None: - out = torch.empty( - quant_state.shape, dtype=quant_state.dtype, device=A.device - ) + out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) n = out.numel() # Map nf4 to [-1, 1] @@ -443,9 +460,11 @@ def dequantize_4bit_impl( rem = n % blocksize has_rem = rem > 0 out_reshaped = out.reshape(-1) - out_reshaped[:n - rem] = (out_dq[:n - rem].view(-1, blocksize) * absmax[:blocks - has_rem].view(-1, 1)).reshape(-1) + out_reshaped[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape( + -1 + ) if has_rem: - out_reshaped[n - rem:] = out_dq[n - rem:] * absmax[-1] + out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1] # take transpose here because weight is transposed (again) for computation return out.t() diff --git a/tests/test_functional.py b/tests/test_functional.py index 8e125f712..ea15f148a 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2003,7 +2003,8 @@ def test_bench_dequantization(): @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096]) -def test_4bit_quant(dtype, quant_type, blocksize): +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +def test_4bit_quant(dtype, quant_type, blocksize, device): vals = list(product([0, 1], repeat=4)) code = {} @@ -2027,9 +2028,11 @@ def test_4bit_quant(dtype, quant_type, blocksize): result = sign * exp * frac code[idx] = result - A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) + A1 = torch.randn(1024, 1024, device=device, dtype=dtype) qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) + if device == "cpu": + A2 = A2.t() err = (A1 - A2).abs().float() relerr = (err / (A1.abs().float() + 1e-8)).mean() @@ -2279,6 +2282,49 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): assert maxratio < 1.02 and maxratio > 0.98 +@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"]) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) +def test_gemv_4bit_cpu(dtype, quant_type, kind): + """ + Test 4bit GEMV for CPU. It is simplified a lot from the cuda version, since + the CPU backend does not support double_quant or quant_storage other than uint8. + Also, the CPU backend has different numeric accuracy from that of CUDA + """ + for dim in [128, 256, 512, 1024]: + for i in range(10): + if kind == "fc1": + A = torch.randn(1, dim, dtype=dtype, device="cpu") + B = torch.randn(dim * 4, dim, dtype=dtype, device="cpu") / math.sqrt(dim) + elif kind == "fc2": + A = torch.randn(1, 4 * dim, dtype=dtype, device="cpu") + B = torch.randn(dim, 4 * dim, dtype=dtype, device="cpu") / math.sqrt(dim) + elif kind == "attn": + A = torch.randn(1, dim, dtype=dtype, device="cpu") + B = torch.randn(dim, dim, dtype=dtype, device="cpu") / math.sqrt(dim) + elif kind == "attn_packed": + A = torch.randn(1, dim, dtype=dtype, device="cpu") + B = torch.randn(dim * 3, dim, dtype=dtype, device="cpu") / math.sqrt(dim) + + qB, state = F.quantize_4bit( + B, + quant_type=quant_type, + compress_statistics=False, + quant_storage=torch.uint8, + ) + dqB = F.dequantize_4bit(qB, state) + C3 = torch.matmul(A, dqB) + C2 = F.gemv_4bit(A, qB.t(), state=state) + A.requires_grad = True + C1 = bnb.matmul_4bit(A, qB.t(), state) + + c = int(C1.numel() * 0.0014 * (dim / 256)) + 1 + rtol = 1e-3 if dtype != torch.bfloat16 else 1e-2 + atol = 1e-2 if dtype != torch.bfloat16 else 5e-2 + assert_all_approx_close(C1, C2, rtol, atol, count=c) + assert_all_approx_close(C3, C2, rtol, atol, count=c) + + @pytest.mark.skip("Row scale has some bugs for ampere") def test_managed(): n = 32 * 10 From dd15734709f131b4c1e3244ba28e632dbf5a3ed6 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Fri, 10 May 2024 23:57:25 -0700 Subject: [PATCH 4/6] Reduce memory usage --- bitsandbytes/backends/cpu_xpu_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index 8d87f7e2f..426d07975 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -377,6 +377,7 @@ def quantize_4bit_impl( int(lowp_mode), -1, # act_quant_mode. -1 means don't quant activation ) + return torch.Tensor(), state return out, state From 85a01b00fc131a586dec8fec5d25d753a471006c Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Sat, 11 May 2024 00:42:31 -0700 Subject: [PATCH 5/6] Fix UT --- bitsandbytes/backends/cpu_xpu_common.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index 426d07975..7c35a85c3 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -440,6 +440,11 @@ def dequantize_4bit_impl( if quant_state.nested: raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU") + if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(quant_state, "op_context"): + assert quant_state.op_context is not None + A = quant_state.op_context.to_public(quant_state.op_context.get_weight()) + A = A.reshape(-1) + if out is None: out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) @@ -503,7 +508,7 @@ def gemm_4bit_impl( torch.Tensor: GEMM output tensor. """ - if ipex_cpu and _ipex_cpu_version_prereq(2, 2) and hasattr(state, "op_context"): + if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"): assert state.op_context is not None output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle()) else: From 2c489f8dde8e5992af5aa0956e1a4cb9554b72eb Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Sat, 11 May 2024 00:54:17 -0700 Subject: [PATCH 6/6] reduce memory usage for nf4 --- bitsandbytes/backends/cpu_xpu_common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py index 7c35a85c3..138ec72f5 100644 --- a/bitsandbytes/backends/cpu_xpu_common.py +++ b/bitsandbytes/backends/cpu_xpu_common.py @@ -377,6 +377,7 @@ def quantize_4bit_impl( int(lowp_mode), -1, # act_quant_mode. -1 means don't quant activation ) + state.absmax = torch.Tensor() return torch.Tensor(), state return out, state @@ -444,6 +445,7 @@ def dequantize_4bit_impl( assert quant_state.op_context is not None A = quant_state.op_context.to_public(quant_state.op_context.get_weight()) A = A.reshape(-1) + absmax = quant_state.op_context.get_scales().reshape(-1) if out is None: out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)