fix 4bit format
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
jiqing-feng committed Jan 20, 2025
1 parent 96f4ac8 commit 5f78858
Showing 3 changed files with 36 additions and 11 deletions.
19 changes: 14 additions & 5 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -3,12 +3,14 @@
import warnings

import torch
import torch.nn.functional as F

from bitsandbytes.functional import (
QuantState,
create_dynamic_map,
get_4bit_type,
)
from bitsandbytes.utils import reverse_4bit_compress_format

try:
# to support Intel CPU/GPU (XPU) backend
@@ -367,7 +369,7 @@ def quantize_4bit_impl(
else:
if out_uint8.size(-1) % 2:
out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])
out[:] = out_uint8[::2].bitwise_left_shift(4).bitwise_or_(out_uint8[1::2])
code = get_4bit_type(quant_type, device=A.device)

if compress_statistics:
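
Illustration only, not part of this commit: with the corrected packing line in quantize_4bit_impl, the even-indexed 4-bit code lands in the high nibble and the odd-indexed code in the low nibble of each packed byte (the previous line had the two halves swapped).

import torch

# hypothetical example values; each entry is a 4-bit quantization code
out_uint8 = torch.tensor([0x1, 0x2, 0x3, 0x4], dtype=torch.uint8)
packed = out_uint8[::2].bitwise_left_shift(4).bitwise_or_(out_uint8[1::2])
print([hex(v) for v in packed.tolist()])  # ['0x12', '0x34']
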
@@ -401,7 +403,13 @@ def quantize_4bit_impl(
def dequant_8bit(A, offset, quant_state):
assert A.dtype == torch.uint8
absmax = quant_state.code[A.reshape(-1).int()]
absmax = (absmax.view(-1, 256) * quant_state.absmax.view(-1, 1)).to(quant_state.dtype).reshape(A.shape)
blocks = absmax.shape[-1] // 256
res = absmax.shape[-1] % 256
if res != 0:
absmax = F.pad(absmax, (0, 256 - res), mode="constant", value=0)
absmax = (absmax.view(-1, 256) * quant_state.absmax.view(-1, 1)).to(quant_state.dtype).reshape(-1)
absmax = absmax[: blocks * 256 + res]
absmax = absmax.reshape(A.shape)
absmax += offset
return absmax
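
A minimal sketch of the padding arithmetic added here, with hypothetical sizes (not the commit's code): when the number of absmax entries is not a multiple of the 256-value block size, the tensor is padded before the per-block multiply and the padded tail is dropped afterwards.

import torch
import torch.nn.functional as F

n = 300                                  # hypothetical count, not a multiple of 256
absmax = torch.rand(n)                   # stand-in for the looked-up code values
block_scales = torch.rand(2, 1)          # one scale per 256-value block (2 after padding)

blocks = n // 256                        # 1 full block
res = n % 256                            # 44 trailing values
if res != 0:
    absmax = F.pad(absmax, (0, 256 - res), value=0)        # pad length 300 -> 512
out = (absmax.view(-1, 256) * block_scales).reshape(-1)    # per-block scaling
out = out[: blocks * 256 + res]                            # drop the padding -> 300 values
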

@@ -471,14 +479,15 @@ def dequantize_4bit_impl(
absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2)

if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
A = reverse_4bit_compress_format(ipex_weight)
quant_state.ipex = False

# Map nf4 to [-1, 1]
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
n = out_dq.numel()
out_dq[::2] = A & 0xF
out_dq[1::2] = A >> 4
out_dq[1::2] = A & 0xF
out_dq[::2] = A >> 4
# quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue
quant_state.code = quant_state.code.to(quant_state.dtype)
out_dq = quant_state.code[out_dq]
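
Round-trip check, illustration only: with the matching change on the dequantize side, unpacking restores the 4-bit codes to their original positions.

import torch

codes = torch.tensor([0x1, 0x2, 0x3, 0x4], dtype=torch.uint8)
packed = codes[::2].bitwise_left_shift(4).bitwise_or_(codes[1::2])  # quantize-side packing

out = torch.empty(packed.numel() * 2, dtype=torch.uint8)
out[1::2] = packed & 0xF   # low nibble -> odd positions
out[::2] = packed >> 4     # high nibble -> even positions
assert torch.equal(out, codes)
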
5 changes: 3 additions & 2 deletions bitsandbytes/nn/modules.py
@@ -20,6 +20,7 @@
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
OutlierTracer,
enable_ipex_fusion,
reverse_4bit_compress_format,
)

T = TypeVar("T", bound="torch.nn.Module")
@@ -460,9 +461,9 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(
self.weight, "nf4", self.weight.quant_state.shape, 2
)
self.weight.data = original_weight.data
self.weight.data = reverse_4bit_compress_format(original_weight.data)
elif self.weight.device.type == "xpu":
self.weight.data = self.weight.data.reshape(1, -1)
self.weight.data = reverse_4bit_compress_format(self.weight.data.reshape(1, -1))

self.weight.quant_state.ipex = False

23 changes: 19 additions & 4 deletions bitsandbytes/utils.py
@@ -200,13 +200,22 @@ def unpack_tensor_to_dict(tensor_data):
return unpacked_dict


def reverse_4bit_compress_format(weight):
out_1 = torch.empty(weight.size(0), dtype=torch.int32, device=weight.device)
out_2 = torch.empty(weight.size(0), dtype=torch.int32, device=weight.device)
out_1 = (weight & 0xF0) >> 4
out_2 = (weight & 0xF) << 4
out = out_1 | out_2
return out
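
For reference (not part of the diff): the helper swaps the two nibbles of every byte, e.g. 0x12 becomes 0x21, and applying it twice returns the original tensor, which is why the same function can convert in both directions between the bitsandbytes and ipex 4-bit layouts.

import torch

w = torch.tensor([0x12, 0xAB], dtype=torch.uint8)
swapped = ((w & 0xF0) >> 4) | ((w & 0x0F) << 4)
print([hex(v) for v in swapped.tolist()])   # ['0x21', '0xba']
restored = ((swapped & 0xF0) >> 4) | ((swapped & 0x0F) << 4)
assert torch.equal(restored, w)             # nibble swap is its own inverse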


def enable_ipex_fusion(linear, x):
from bitsandbytes.backends.cpu_xpu_common import (
_ipex_cpu_version_prereq,
_ipex_xpu_version_prereq,
dequant_8bit,
ipex_cpu,
ipex_xpu,
dequant_8bit,
)

quant_state = linear.weight.quant_state
@@ -217,8 +226,9 @@ def enable_ipex_fusion(linear, x):
delattr(quant_state, "state2")

if x.device.type == "cpu" and ipex_cpu and _ipex_cpu_version_prereq(2, 5):
converted_weight = reverse_4bit_compress_format(linear.weight.data)
new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
"nf4",
quant_state.shape, # weight shape
quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
@@ -229,11 +239,16 @@ def enable_ipex_fusion(linear, x):
2,
)
elif x.device.type == "xpu" and ipex_xpu and _ipex_xpu_version_prereq(2, 5):
new_weight = linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2])

converted_weight = reverse_4bit_compress_format(linear.weight.data)
new_weight = converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
new_zeros = None
compensation = None
else:
raise ValueError(
"Please check the device and ipex version. The device should be cpu or xpu while ipex version should >= 2.5"
)

linear.weight.data = new_weight.data
linear.weight.quant_state.ipex = True
linear.weight.quant_state.new_scales = new_scales
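
A rough shape sketch with made-up sizes (not the actual ipex_prepack call): after the format conversion added above, enable_ipex_fusion hands ipex a [out_features, in_features // 2] uint8 weight (two 4-bit values per byte) together with one scale per blocksize chunk.

import torch

out_features, in_features, blocksize = 8, 64, 32            # hypothetical sizes
packed = torch.randint(0, 256, (out_features * in_features // 2,), dtype=torch.uint8)
absmax = torch.rand(out_features * in_features // blocksize)

converted = ((packed & 0xF0) >> 4) | ((packed & 0x0F) << 4)       # nibble swap for ipex
new_weight = converted.reshape(out_features, in_features // 2)    # shape [8, 32]
new_scales = absmax.view(out_features, in_features // blocksize)  # shape [8, 2]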
