Enable double quant on Intel CPU and XPU (#1472)
* fix dequant 8bit

Signed-off-by: jiqing-feng <[email protected]>

* support double quant on intel cpu and xpu

Signed-off-by: jiqing-feng <[email protected]>

* fix format

Signed-off-by: jiqing-feng <[email protected]>

* fix shape

Signed-off-by: jiqing-feng <[email protected]>

* fix 4bit format

Signed-off-by: jiqing-feng <[email protected]>

* fix device error for xpu

Signed-off-by: jiqing-feng <[email protected]>

* fix 4bit tensor shape

Signed-off-by: jiqing-feng <[email protected]>

* fix nf4 xpu finetune

Signed-off-by: jiqing-feng <[email protected]>

---------

Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng authored Jan 22, 2025
1 parent 7e6f865 commit f6025bc
Showing 4 changed files with 83 additions and 26 deletions.
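
For context, a minimal usage sketch of what this change enables: before this commit, setting bnb_4bit_use_double_quant on CPU/XPU raised NotImplementedError; with it, double quantization of the 4-bit absmax statistics is expected to work. This Transformers-based example is a hedged illustration only: the model name, compute dtype, and default device placement are placeholders and are not part of this commit.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Double quantization of the 4-bit absmax statistics, now supported on CPU/XPU.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", quantization_config=bnb_config)
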
71 changes: 53 additions & 18 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -3,11 +3,14 @@
import warnings

import torch
import torch.nn.functional as F

from bitsandbytes.functional import (
QuantState,
create_dynamic_map,
get_4bit_type,
)
from bitsandbytes.utils import reverse_4bit_compress_format

try:
# to support Intel CPU/GPU (XPU) backend
@@ -279,8 +282,9 @@ def mm_dequant_impl(
0.8333333: 3, # 0b0011
}

INT8_QUANT_TABLE = create_dynamic_map().tolist()


@_maybe_torch_compile
def quantize_4bit_impl(
A: Tensor,
absmax: Tensor = None,
@@ -314,7 +318,7 @@ def quantize_4bit_impl(
tuple(torch.Tensor, torch.Size, torch.dtype, int):
The quantization state to undo the quantization.
"""
if quant_type not in ["nf4", "fp4"]:
if quant_type not in ["nf4", "fp4", "int8"]:
raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.")
if quant_type == "fp4":
warnings.warn("fp4 quantization is currently slow on CPU/XPU. Please Use nf4 instead for better performance.")
@@ -355,14 +359,34 @@ def quantize_4bit_impl(
for key, val in FP4_QUANT_TABLE.items():
out_uint8[abs_scaled_A > key] = val
out_uint8 += sign.to(torch.uint8) * 8
if out_uint8.size(-1) % 2:
out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])
elif quant_type == "int8":
for i in range(len(INT8_QUANT_TABLE)):
out_uint8[scaled_A > INT8_QUANT_TABLE[i]] = i

code = get_4bit_type(quant_type, device=A.device)
if quant_type == "int8":
out = out_uint8
code = torch.Tensor(INT8_QUANT_TABLE).to(A.device)
else:
if out_uint8.size(-1) % 2:
out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
out[:] = out_uint8[::2].bitwise_left_shift(4).bitwise_or_(out_uint8[1::2])
code = get_4bit_type(quant_type, device=A.device)

if compress_statistics:
raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
offset = absmax.mean()
absmax -= offset
qabsmax, state2 = quantize_4bit_impl(absmax, blocksize=256, quant_type="int8")
del absmax
state = QuantState(
absmax=qabsmax,
shape=input_shape,
dtype=A.dtype,
blocksize=blocksize,
code=code,
quant_type=quant_type,
offset=offset,
state2=state2,
)
else:
state = QuantState(
absmax=absmax,
@@ -373,7 +397,21 @@ def quantize_4bit_impl(
quant_type=quant_type,
)

return out.unsqueeze(0), state
return out.reshape(-1, 1), state


def dequant_8bit(A, offset, quant_state):
assert A.dtype == torch.uint8
absmax = quant_state.code[A.reshape(-1).int()]
blocks = absmax.shape[-1] // 256
res = absmax.shape[-1] % 256
if res != 0:
absmax = F.pad(absmax, (0, 256 - res), mode="constant", value=0)
absmax = (absmax.view(-1, 256) * quant_state.absmax.view(-1, 1)).to(quant_state.dtype).reshape(-1)
absmax = absmax[: blocks * 256 + res]
absmax = absmax.reshape(A.shape)
absmax += offset
return absmax


@_maybe_torch_compile
@@ -411,12 +449,8 @@ def dequantize_4bit_impl(
torch.Tensor:
Dequantized tensor.
"""
if A.shape[0] == 1:
transpose = False
A = A.squeeze(0)
elif A.shape[1] == 1:
transpose = True
A = A.squeeze(1)
transpose = True if A.shape[0] == 1 else False
A = A.reshape(-1)

if quant_state is None:
assert absmax is not None and out is not None
@@ -438,17 +472,18 @@
)

if quant_state.nested:
raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2)

if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
A = reverse_4bit_compress_format(ipex_weight)
quant_state.ipex = False

# Map nf4 to [-1, 1]
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
n = out_dq.numel()
out_dq[::2] = A & 0xF
out_dq[1::2] = A >> 4
out_dq[1::2] = A & 0xF
out_dq[::2] = A >> 4
# quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue
quant_state.code = quant_state.code.to(quant_state.dtype)
out_dq = quant_state.code[out_dq]
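
The double-quant path added above can be summarized as: center the per-block absmax values by their mean (the offset), quantize the centered values to 8 bits against a lookup table in blocks of 256, and store the offset plus a nested QuantState; dequant_8bit reverses the lookup, rescales per block, and adds the offset back. Below is a simplified, self-contained sketch of that idea, not the library code itself: the helper names are hypothetical, a uniform table stands in for create_dynamic_map(), and nearest-entry rounding replaces the threshold scan used in the real implementation.

import torch

def toy_quant_absmax(absmax, code, blocksize=256):
    # Center by the mean, then quantize each 256-element block against `code`.
    offset = absmax.mean()
    centered = absmax - offset
    pad = (-centered.numel()) % blocksize
    blocks = torch.nn.functional.pad(centered, (0, pad)).view(-1, blocksize)
    scale = blocks.abs().max(dim=1, keepdim=True).values.clamp(min=1e-12)
    scaled = blocks / scale
    # Nearest entry in the quantization table (the library scans thresholds instead).
    idx = (scaled.flatten().unsqueeze(1) - code.unsqueeze(0)).abs().argmin(dim=1)
    return idx.to(torch.uint8), scale.flatten(), offset

def toy_dequant_absmax(idx, scale, offset, code, n, blocksize=256):
    # Look up, rescale per block, trim the padding, and add the offset back.
    vals = code[idx.long()].view(-1, blocksize) * scale.view(-1, 1)
    return vals.flatten()[:n] + offset

code = torch.linspace(-1.0, 1.0, 256)   # stand-in for create_dynamic_map()
absmax = torch.rand(1000) * 5
q, scale, offset = toy_quant_absmax(absmax, code)
restored = toy_dequant_absmax(q, scale, offset, code, absmax.numel())
print((restored - absmax).abs().max())  # small reconstruction error
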
2 changes: 1 addition & 1 deletion bitsandbytes/backends/xpu.py
@@ -155,7 +155,7 @@ def dequantize_4bit(
if blocksize is None:
blocksize = 64
assert_on_xpu([A, absmax, out])
if quant_type == "nf4":
if quant_type == "nf4" and getattr(quant_state, "ipex", False):
output = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t()
else:
output = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)
5 changes: 3 additions & 2 deletions bitsandbytes/nn/modules.py
@@ -20,6 +20,7 @@
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
OutlierTracer,
enable_ipex_fusion,
reverse_4bit_compress_format,
)

T = TypeVar("T", bound="torch.nn.Module")
@@ -460,9 +461,9 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(
self.weight, "nf4", self.weight.quant_state.shape, 2
)
self.weight.data = original_weight.data
self.weight.data = reverse_4bit_compress_format(original_weight.data)
elif self.weight.device.type == "xpu":
self.weight.data = self.weight.data.reshape(1, -1)
self.weight.data = reverse_4bit_compress_format(self.weight.data.reshape(1, -1))

self.weight.quant_state.ipex = False

31 changes: 26 additions & 5 deletions bitsandbytes/utils.py
@@ -200,18 +200,35 @@ def unpack_tensor_to_dict(tensor_data):
return unpacked_dict


def reverse_4bit_compress_format(weight):
out_1 = torch.empty(weight.size(0), dtype=torch.int32, device=weight.device)
out_2 = torch.empty(weight.size(0), dtype=torch.int32, device=weight.device)
out_1 = (weight & 0xF0) >> 4
out_2 = (weight & 0xF) << 4
out = out_1 | out_2
return out


def enable_ipex_fusion(linear, x):
from bitsandbytes.backends.cpu_xpu_common import (
_ipex_cpu_version_prereq,
_ipex_xpu_version_prereq,
dequant_8bit,
ipex_cpu,
ipex_xpu,
)

quant_state = linear.weight.quant_state

if quant_state.nested:
quant_state.absmax = dequant_8bit(quant_state.absmax, quant_state.offset, quant_state.state2)
quant_state.nested = False
delattr(quant_state, "state2")

if x.device.type == "cpu" and ipex_cpu and _ipex_cpu_version_prereq(2, 5):
quant_state = linear.weight.quant_state
converted_weight = reverse_4bit_compress_format(linear.weight.data)
new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
"nf4",
quant_state.shape, # weight shape
quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
@@ -222,12 +239,16 @@ def enable_ipex_fusion(linear, x):
2,
)
elif x.device.type == "xpu" and ipex_xpu and _ipex_xpu_version_prereq(2, 5):
quant_state = linear.weight.quant_state
new_weight = linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2])

converted_weight = reverse_4bit_compress_format(linear.weight.data)
new_weight = converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
new_zeros = None
compensation = None
else:
raise ValueError(
"Please check the device and ipex version. The device should be cpu or xpu while ipex version should >= 2.5"
)

linear.weight.data = new_weight.data
linear.weight.quant_state.ipex = True
linear.weight.quant_state.new_scales = new_scales
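
The new reverse_4bit_compress_format helper above simply swaps the two 4-bit values packed into each byte; in this commit it is applied when packing weights for the IPEX weight-only-quantization ops and when unpacking them back. A quick illustration with hypothetical byte values, run on its own rather than through the library:

import torch

w = torch.tensor([0xAB, 0x12], dtype=torch.uint8)
swapped = ((w & 0xF0) >> 4) | ((w & 0x0F) << 4)
print([hex(v) for v in swapped.tolist()])  # ['0xba', '0x21']: high and low nibbles exchanged
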
