Add fp4 support; add UT; fix lint issues
Xia-Weiwen committed May 11, 2024
1 parent 177bd39 commit 881b5fc
Showing 3 changed files with 114 additions and 49 deletions.
4 changes: 2 additions & 2 deletions bitsandbytes/backends/cpu.py
@@ -6,12 +6,12 @@

from .base import Backend
from .cpu_xpu_common import (
dequantize_4bit_impl,
double_quant_impl,
gemm_4bit_impl,
igemmlt_impl,
mm_dequant_impl,
quantize_4bit_impl,
dequantize_4bit_impl,
gemm_4bit_impl,
)

Tensor = torch.Tensor
109 changes: 64 additions & 45 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -1,11 +1,11 @@
from typing import Optional
import warnings

import torch
from typing import Optional

from bitsandbytes.functional import (
get_4bit_type,
quantize_blockwise,
dequantize_blockwise,
QuantState,
get_4bit_type,
)

try:
@@ -237,25 +237,37 @@ def mm_dequant_impl(


NF4_QUANT_TABLE = [
-1.0 - 1e-2, # 0b0000
-0.8480964004993439, # 0b0001
-0.6106329262256622, # 0b0010
-0.4599952697753906, # 0b0011
-0.33967943489551544, # 0b0100
-0.23460740596055984, # 0b0101
-0.13791173323988914, # 0b0110
-0.045525018125772476, # 0b0111
0.03979014977812767, # 0b1000
0.1202552504837513, # 0b1001
0.2035212516784668, # 0b1010
0.2920137718319893, # 0b1011
0.3893125355243683, # 0b1100
0.5016634166240692, # 0b1101
0.6427869200706482, # 0b1110
0.8614784181118011, # 0b1111
]


FP4_QUANT_TABLE = {
0 - 1e-2: 0, # 0b0000
0.00260417: 1, # 0b0001
0.0859375: 6, # 0b0110
0.20833333: 7, # 0b0111
0.29166667: 4, # 0b0100
0.4166667: 5, # 0b0101
0.583333: 2, # 0b0010
0.8333333: 3, # 0b0011
}
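
The keys of FP4_QUANT_TABLE are decision thresholds over the absolute value of the scaled input: each key is (up to rounding) the midpoint between two adjacent representable fp4 magnitudes, and the value is the 3-bit code of the larger magnitude; the sign bit is added separately in quantize_4bit_impl below. A minimal sketch of how these thresholds can be reproduced, assuming get_4bit_type("fp4") returns the 16 fp4 values in code order, normalized to [-1, 1]:

```python
import torch

from bitsandbytes.functional import get_4bit_type

# Assumption: indices 0..7 are the non-negative fp4 values (code order), 8..15 their negatives.
fp4 = get_4bit_type("fp4", device="cpu")
positive = fp4[:8]
order = torch.argsort(positive)  # codes sorted by magnitude: 0, 1, 6, 7, 4, 5, 2, 3

thresholds = {0 - 1e-2: 0}  # guard below the smallest magnitude, as in FP4_QUANT_TABLE
for lo, hi in zip(order[:-1], order[1:]):
    # Values above the midpoint of two adjacent magnitudes round to the larger one.
    midpoint = (positive[lo].item() + positive[hi].item()) / 2
    thresholds[midpoint] = int(hi)

print(thresholds)  # should match FP4_QUANT_TABLE up to floating-point rounding
```
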


# It's faster not to use torch.compile
def quantize_4bit_impl(
A: Tensor,
@@ -290,10 +302,11 @@ def quantize_4bit_impl(
tuple(torch.Tensor, torch.Size, torch.dtype, int):
The quantization state to undo the quantization.
"""
if quant_type != "nf4":
raise NotImplementedError(
f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU."
)
if quant_type not in ["nf4", "fp4"]:
raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.")
if quant_type == "fp4":
warnings.warn("fp4 quantization is currently slow on CPU/XPU. Please Use nf4 instead for better performance.")
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
n = A.numel()
input_shape = A.shape
blocks = n // blocksize
@@ -305,25 +318,31 @@ def quantize_4bit_impl(
if out is None:
out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device)

assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
rem = n % blocksize
has_rem = rem > 0

# Scale tensor to [-1, 1]
A_reshaped = A.reshape(n)
A_com = A_reshaped[:n - rem]
A_com = A_reshaped[: n - rem]
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
absmax[:blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[:blocks - has_rem].view(-1, 1)), -1, 1)
absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
scaled_A = scaled_A.reshape(-1)
if has_rem:
absmax[-1] = torch.abs(A_reshaped[n - rem:]).max()
scaled_A_rem = torch.clamp(A_reshaped[n - rem:] * (1 / absmax[-1]), -1, 1)
absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
# map [-1, 1] to nf4
# map [-1, 1] to nf4/fp4
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8)
for i in range(len(NF4_QUANT_TABLE)):
out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
if quant_type == "nf4":
for i in range(len(NF4_QUANT_TABLE)):
out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
elif quant_type == "fp4":
sign = scaled_A < 0
abs_scaled_A = torch.abs(scaled_A)
for key, val in FP4_QUANT_TABLE.items():
out_uint8[abs_scaled_A > key] = val
out_uint8 += sign.to(torch.uint8) * 8
if out_uint8.size(-1) % 2:
out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])
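
The last two lines pack two 4-bit codes into each output byte: even-indexed elements land in the low nibble, odd-indexed elements in the high nibble (with one zero code of padding when the element count is odd). A self-contained sketch of that layout and its inverse, using hypothetical helper names:

```python
import torch

def pack_4bit(codes: torch.Tensor) -> torch.Tensor:
    # codes: 1-D uint8 tensor of 4-bit values (0..15); two codes are stored per byte,
    # even indices in the low nibble and odd indices in the high nibble.
    if codes.size(-1) % 2:
        codes = torch.nn.functional.pad(codes, (0, 1), value=0)
    return codes[1::2].bitwise_left_shift(4).bitwise_or_(codes[::2])

def unpack_4bit(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_4bit: recover the interleaved 4-bit codes.
    out = torch.empty(packed.numel() * 2, dtype=torch.uint8)
    out[::2] = packed & 0xF   # low nibble -> even elements
    out[1::2] = packed >> 4   # high nibble -> odd elements
    return out

codes = torch.randint(0, 16, (11,), dtype=torch.uint8)
assert torch.equal(unpack_4bit(pack_4bit(codes))[: codes.numel()], codes)
```
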
@@ -342,21 +361,21 @@ def quantize_4bit_impl(
quant_type=quant_type,
)

if ipex_cpu and _ipex_cpu_version_prereq(2, 2) and input_shape[0] % blocksize == 0:
if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4":
# lowp_mode: lowest precision for computation
lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16
state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
out.reshape([input_shape[0], input_shape[1] // 2]),
ipex_cpu.quantization.WoqWeightDtype.NF4,
input_shape, # weight shape
absmax.view(input_shape[0], input_shape[1] // blocksize), # scales
None, # zero_points
None, # bias
None, # g_idx
None, # batch_size
blocksize,
int(lowp_mode),
-1, # act_quant_mode. -1 means don't quant activation
)

return out, state
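
For orientation, a sketch of what the returned QuantState carries on the CPU path. This assumes a bitsandbytes build whose multi-backend dispatch routes CPU tensors to quantize_4bit_impl; the op_context attribute is only populated when the IPEX nf4 prepack branch above is taken:

```python
import torch
import bitsandbytes.functional as F

W = torch.randn(128, 128, dtype=torch.bfloat16, device="cpu")
qW, state = F.quantize_4bit(W, blocksize=64, quant_type="nf4")

print(qW.shape, qW.dtype)          # packed codes, two per uint8 element
print(state.shape, state.dtype)    # original weight shape and dtype, used on dequantize
print(state.blocksize, state.quant_type, state.absmax.shape)  # one absmax scale per block
# op_context exists only if the IPEX weight-only-quant prepack ran
# (ipex >= 2.3, nf4, packed dim divisible by blocksize)
print(getattr(state, "op_context", None) is not None)
```
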
@@ -365,7 +384,7 @@ def quantize_4bit_impl(
@_maybe_torch_compile
def dequantize_4bit_impl(
A: Tensor,
quant_state = None,
quant_state=None,
absmax: Tensor = None,
out: Tensor = None,
blocksize: int = 64,
@@ -412,7 +431,7 @@ def dequantize_4bit_impl(
else:
absmax = quant_state.absmax

if quant_state.quant_type != "nf4":
if quant_type not in ["nf4", "fp4"]:
raise NotImplementedError(
f"4-bit quantization data type {quant_state.quant_type} is not implemented for CPU/XPU."
)
@@ -421,9 +440,7 @@ def dequantize_4bit_impl(
raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")

if out is None:
out = torch.empty(
quant_state.shape, dtype=quant_state.dtype, device=A.device
)
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)

n = out.numel()
# Map nf4 to [-1, 1]
@@ -443,9 +460,11 @@ def dequantize_4bit_impl(
rem = n % blocksize
has_rem = rem > 0
out_reshaped = out.reshape(-1)
out_reshaped[:n - rem] = (out_dq[:n - rem].view(-1, blocksize) * absmax[:blocks - has_rem].view(-1, 1)).reshape(-1)
out_reshaped[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(
-1
)
if has_rem:
out_reshaped[n - rem:] = out_dq[n - rem:] * absmax[-1]
out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1]

# take transpose here because weight is transposed (again) for computation
return out.t()
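
Note that the CPU implementation returns the dequantized weight transposed (the weight is transposed again inside gemm_4bit_impl for computation), which is why the updated test_4bit_quant transposes the result back on CPU. A minimal round-trip sketch under the same multi-backend-dispatch assumption as above:

```python
import torch
import bitsandbytes.functional as F

A1 = torch.randn(256, 256, dtype=torch.float32, device="cpu")
qa, SA = F.quantize_4bit(A1, blocksize=64, quant_type="fp4")
A2 = F.dequantize_4bit(qa, SA, blocksize=64, quant_type="fp4").t()  # undo the CPU-path transpose

relerr = ((A1 - A2).abs() / (A1.abs() + 1e-8)).mean()
print(float(relerr))  # fp4 has coarser spacing than nf4, so expect a somewhat larger error
```
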
50 changes: 48 additions & 2 deletions tests/test_functional.py
@@ -2003,7 +2003,8 @@ def test_bench_dequantization():
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
def test_4bit_quant(dtype, quant_type, blocksize):
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_4bit_quant(dtype, quant_type, blocksize, device):
vals = list(product([0, 1], repeat=4))

code = {}
@@ -2027,9 +2028,11 @@ def test_4bit_quant(dtype, quant_type, blocksize):
result = sign * exp * frac
code[idx] = result

A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
if device == "cpu":
A2 = A2.t()

err = (A1 - A2).abs().float()
relerr = (err / (A1.abs().float() + 1e-8)).mean()
@@ -2279,6 +2282,49 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
assert maxratio < 1.02 and maxratio > 0.98


@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
def test_gemv_4bit_cpu(dtype, quant_type, kind):
"""
Test 4bit GEMV for CPU. It is simplified a lot from the CUDA version, since
the CPU backend does not support double_quant or quant_storage other than uint8.
Also, the CPU backend has different numeric accuracy from that of CUDA.
"""
for dim in [128, 256, 512, 1024]:
for i in range(10):
if kind == "fc1":
A = torch.randn(1, dim, dtype=dtype, device="cpu")
B = torch.randn(dim * 4, dim, dtype=dtype, device="cpu") / math.sqrt(dim)
elif kind == "fc2":
A = torch.randn(1, 4 * dim, dtype=dtype, device="cpu")
B = torch.randn(dim, 4 * dim, dtype=dtype, device="cpu") / math.sqrt(dim)
elif kind == "attn":
A = torch.randn(1, dim, dtype=dtype, device="cpu")
B = torch.randn(dim, dim, dtype=dtype, device="cpu") / math.sqrt(dim)
elif kind == "attn_packed":
A = torch.randn(1, dim, dtype=dtype, device="cpu")
B = torch.randn(dim * 3, dim, dtype=dtype, device="cpu") / math.sqrt(dim)

qB, state = F.quantize_4bit(
B,
quant_type=quant_type,
compress_statistics=False,
quant_storage=torch.uint8,
)
dqB = F.dequantize_4bit(qB, state)
C3 = torch.matmul(A, dqB)
C2 = F.gemv_4bit(A, qB.t(), state=state)
A.requires_grad = True
C1 = bnb.matmul_4bit(A, qB.t(), state)

c = int(C1.numel() * 0.0014 * (dim / 256)) + 1
rtol = 1e-3 if dtype != torch.bfloat16 else 1e-2
atol = 1e-2 if dtype != torch.bfloat16 else 5e-2
assert_all_approx_close(C1, C2, rtol, atol, count=c)
assert_all_approx_close(C3, C2, rtol, atol, count=c)


@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
n = 32 * 10
