From 5f78858ebdfdbfe3d2e2ca3d62bbb49ab3030255 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 20 Jan 2025 12:53:55 +0000
Subject: [PATCH] fix 4bit format

Signed-off-by: jiqing-feng
---
 bitsandbytes/backends/cpu_xpu_common.py | 19 ++++++++++++++-----
 bitsandbytes/nn/modules.py              |  5 +++--
 bitsandbytes/utils.py                   | 23 +++++++++++++++++++----
 3 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py
index 733a73410..dd4f9f41a 100644
--- a/bitsandbytes/backends/cpu_xpu_common.py
+++ b/bitsandbytes/backends/cpu_xpu_common.py
@@ -3,12 +3,14 @@
 import warnings
 
 import torch
+import torch.nn.functional as F
 
 from bitsandbytes.functional import (
     QuantState,
     create_dynamic_map,
     get_4bit_type,
 )
+from bitsandbytes.utils import reverse_4bit_compress_format
 
 try:
     # to support Intel CPU/GPU (XPU) backend
@@ -367,7 +369,7 @@ def quantize_4bit_impl(
     else:
         if out_uint8.size(-1) % 2:
             out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
-        out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])
+        out[:] = out_uint8[::2].bitwise_left_shift(4).bitwise_or_(out_uint8[1::2])
 
     code = get_4bit_type(quant_type, device=A.device)
     if compress_statistics:
@@ -401,7 +403,13 @@ def quantize_4bit_impl(
 def dequant_8bit(A, offset, quant_state):
     assert A.dtype == torch.uint8
     absmax = quant_state.code[A.reshape(-1).int()]
-    absmax = (absmax.view(-1, 256) * quant_state.absmax.view(-1, 1)).to(quant_state.dtype).reshape(A.shape)
+    blocks = absmax.shape[-1] // 256
+    res = absmax.shape[-1] % 256
+    if res != 0:
+        absmax = F.pad(absmax, (0, 256 - res), mode="constant", value=0)
+    absmax = (absmax.view(-1, 256) * quant_state.absmax.view(-1, 1)).to(quant_state.dtype).reshape(-1)
+    absmax = absmax[: blocks * 256 + res]
+    absmax = absmax.reshape(A.shape)
     absmax += offset
     return absmax
 
@@ -471,14 +479,15 @@ def dequantize_4bit_impl(
         absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2)
 
     if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
-        A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
+        ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
+        A = reverse_4bit_compress_format(ipex_weight)
         quant_state.ipex = False
 
     # Map nf4 to [-1, 1]
     out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
     n = out_dq.numel()
-    out_dq[::2] = A & 0xF
-    out_dq[1::2] = A >> 4
+    out_dq[1::2] = A & 0xF
+    out_dq[::2] = A >> 4
 
     # quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue
     quant_state.code = quant_state.code.to(quant_state.dtype)
     out_dq = quant_state.code[out_dq]
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index ad5a7d443..2320ffd39 100755
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -20,6 +20,7 @@
     LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
     OutlierTracer,
     enable_ipex_fusion,
+    reverse_4bit_compress_format,
 )
 
 T = TypeVar("T", bound="torch.nn.Module")
@@ -460,9 +461,9 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
                 original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(
                     self.weight, "nf4", self.weight.quant_state.shape, 2
                 )
-                self.weight.data = original_weight.data
+                self.weight.data = reverse_4bit_compress_format(original_weight.data)
             elif self.weight.device.type == "xpu":
-                self.weight.data = self.weight.data.reshape(1, -1)
+                self.weight.data = reverse_4bit_compress_format(self.weight.data.reshape(1, -1))
 
             self.weight.quant_state.ipex = False
 
diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py
index 6200bf0bd..e3748685e 100644
--- a/bitsandbytes/utils.py
+++ b/bitsandbytes/utils.py
@@ -200,13 +200,22 @@ def unpack_tensor_to_dict(tensor_data):
     return unpacked_dict
 
 
+def reverse_4bit_compress_format(weight):
+    out_1 = torch.empty(weight.size(0), dtype=torch.int32, device=weight.device)
+    out_2 = torch.empty(weight.size(0), dtype=torch.int32, device=weight.device)
+    out_1 = (weight & 0xF0) >> 4
+    out_2 = (weight & 0xF) << 4
+    out = out_1 | out_2
+    return out
+
+
 def enable_ipex_fusion(linear, x):
     from bitsandbytes.backends.cpu_xpu_common import (
         _ipex_cpu_version_prereq,
         _ipex_xpu_version_prereq,
+        dequant_8bit,
         ipex_cpu,
         ipex_xpu,
-        dequant_8bit,
     )
 
     quant_state = linear.weight.quant_state
@@ -217,8 +226,9 @@ def enable_ipex_fusion(linear, x):
         delattr(quant_state, "state2")
 
     if x.device.type == "cpu" and ipex_cpu and _ipex_cpu_version_prereq(2, 5):
+        converted_weight = reverse_4bit_compress_format(linear.weight.data)
         new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
-            linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
+            converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
             "nf4",
             quant_state.shape,  # weight shape
             quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize),  # scales
@@ -229,11 +239,16 @@ def enable_ipex_fusion(linear, x):
             2,
         )
     elif x.device.type == "xpu" and ipex_xpu and _ipex_xpu_version_prereq(2, 5):
-        new_weight = linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
-
+        converted_weight = reverse_4bit_compress_format(linear.weight.data)
+        new_weight = converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
         new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
         new_zeros = None
         compensation = None
+    else:
+        raise ValueError(
+            "Please check the device and ipex version. The device should be cpu or xpu while ipex version should >= 2.5"
+        )
+
     linear.weight.data = new_weight.data
     linear.weight.quant_state.ipex = True
     linear.weight.quant_state.new_scales = new_scales
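
Reviewer note (illustration only, not part of the patch): the sketch below shows the nibble layout this change settles on. After the patch, quantize_4bit_impl packs the even-indexed 4-bit code into the high nibble of each byte and the odd-indexed code into the low nibble, and reverse_4bit_compress_format swaps the two nibbles of every byte, which is how the patch converts between this layout and the one handed to, or recovered from, torch.ops.ipex_prepack.woq_linear_pack_weight / woq_linear_unpack_weight. The pack_4bit and unpack_4bit helpers and the toy tensor are made up for the example; reverse_4bit_compress_format mirrors the helper added to bitsandbytes/utils.py.

    import torch
    import torch.nn.functional as F

    def pack_4bit(codes: torch.Tensor) -> torch.Tensor:
        # Post-patch layout: even-indexed code -> high nibble, odd-indexed -> low nibble.
        if codes.size(-1) % 2:
            codes = F.pad(codes, (0, 1), value=0)
        return codes[::2].bitwise_left_shift(4).bitwise_or(codes[1::2])

    def unpack_4bit(packed: torch.Tensor) -> torch.Tensor:
        # Mirror of dequantize_4bit_impl after the patch: high nibble first, then low nibble.
        out = torch.empty(packed.numel() * 2, dtype=torch.uint8, device=packed.device)
        out[::2] = packed >> 4
        out[1::2] = packed & 0xF
        return out

    def reverse_4bit_compress_format(weight: torch.Tensor) -> torch.Tensor:
        # Same nibble swap as the helper added in bitsandbytes/utils.py.
        return ((weight & 0xF0) >> 4) | ((weight & 0xF) << 4)

    codes = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.uint8)       # toy 4-bit codes
    packed = pack_4bit(codes)
    assert torch.equal(unpack_4bit(packed), codes)                    # round-trip holds
    swapped = reverse_4bit_compress_format(packed)                    # opposite nibble order
    assert torch.equal(reverse_4bit_compress_format(swapped), packed)  # swap is its own inverse

Both asserts should pass on any recent PyTorch build; the snippet is only meant to make the packing-order change and the nibble swap easy to verify by eye.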