From 881b5fcd0bc77f747850f397a0bf02c288332c17 Mon Sep 17 00:00:00 2001
From: "Xia, Weiwen"
Date: Fri, 10 May 2024 22:34:32 -0700
Subject: [PATCH] Add fp4 support; add UT; fix lint issues

---
 bitsandbytes/backends/cpu.py            |   4 +-
 bitsandbytes/backends/cpu_xpu_common.py | 109 ++++++++++++++----
 tests/test_functional.py                |  50 ++++++++++-
 3 files changed, 114 insertions(+), 49 deletions(-)

diff --git a/bitsandbytes/backends/cpu.py b/bitsandbytes/backends/cpu.py
index 80b6c241e..2c3688251 100644
--- a/bitsandbytes/backends/cpu.py
+++ b/bitsandbytes/backends/cpu.py
@@ -6,12 +6,12 @@
 
 from .base import Backend
 from .cpu_xpu_common import (
+    dequantize_4bit_impl,
     double_quant_impl,
+    gemm_4bit_impl,
     igemmlt_impl,
     mm_dequant_impl,
     quantize_4bit_impl,
-    dequantize_4bit_impl,
-    gemm_4bit_impl,
 )
 
 Tensor = torch.Tensor
diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py
index ab881c6dd..8d87f7e2f 100644
--- a/bitsandbytes/backends/cpu_xpu_common.py
+++ b/bitsandbytes/backends/cpu_xpu_common.py
@@ -1,11 +1,11 @@
+from typing import Optional
 import warnings
+
 import torch
-from typing import Optional
+
 from bitsandbytes.functional import (
-    get_4bit_type,
-    quantize_blockwise,
-    dequantize_blockwise,
     QuantState,
+    get_4bit_type,
 )
 
 try:
@@ -237,25 +237,37 @@ def mm_dequant_impl(
 
 
 NF4_QUANT_TABLE = [
-    -1.0 - 1e-2, # 0b0000
-    -0.8480964004993439, # 0b0001
-    -0.6106329262256622, # 0b0010
-    -0.4599952697753906, # 0b0011
+    -1.0 - 1e-2,  # 0b0000
+    -0.8480964004993439,  # 0b0001
+    -0.6106329262256622,  # 0b0010
+    -0.4599952697753906,  # 0b0011
     -0.33967943489551544,  # 0b0100
     -0.23460740596055984,  # 0b0101
     -0.13791173323988914,  # 0b0110
-    -0.045525018125772476, # 0b0111
-    0.03979014977812767, # 0b1000
-    0.1202552504837513, # 0b1001
-    0.2035212516784668, # 0b1010
-    0.2920137718319893, # 0b1011
-    0.3893125355243683, # 0b1100
-    0.5016634166240692, # 0b1101
-    0.6427869200706482, # 0b1110
-    0.8614784181118011, # 0b1111
+    -0.045525018125772476,  # 0b0111
+    0.03979014977812767,  # 0b1000
+    0.1202552504837513,  # 0b1001
+    0.2035212516784668,  # 0b1010
+    0.2920137718319893,  # 0b1011
+    0.3893125355243683,  # 0b1100
+    0.5016634166240692,  # 0b1101
+    0.6427869200706482,  # 0b1110
+    0.8614784181118011,  # 0b1111
 ]
 
 
+FP4_QUANT_TABLE = {
+    0 - 1e-2: 0,  # 0b0000
+    0.00260417: 1,  # 0b0001
+    0.0859375: 6,  # 0b0110
+    0.20833333: 7,  # 0b0111
+    0.29166667: 4,  # 0b0100
+    0.4166667: 5,  # 0b0101
+    0.583333: 2,  # 0b0010
+    0.8333333: 3,  # 0b0011
+}
+
+
 # It's faster not to use torch.compile
 def quantize_4bit_impl(
     A: Tensor,
@@ -290,10 +302,11 @@
     tuple(torch.Tensor, torch.Size, torch.dtype, int):
         The quantization state to undo the quantization.
     """
-    if quant_type != "nf4":
-        raise NotImplementedError(
-            f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU."
-        )
+    if quant_type not in ["nf4", "fp4"]:
+        raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.")
+    if quant_type == "fp4":
+        warnings.warn("fp4 quantization is currently slow on CPU/XPU. Please use nf4 instead for better performance.")
+    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
     n = A.numel()
     input_shape = A.shape
     blocks = n // blocksize
@@ -305,25 +318,31 @@
     if out is None:
         out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device)
 
-    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
     rem = n % blocksize
     has_rem = rem > 0
 
     # Scale tensor to [-1, 1]
     A_reshaped = A.reshape(n)
-    A_com = A_reshaped[:n - rem]
+    A_com = A_reshaped[: n - rem]
     A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
-    absmax[:blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
-    scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[:blocks - has_rem].view(-1, 1)), -1, 1)
+    absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
+    scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
     scaled_A = scaled_A.reshape(-1)
     if has_rem:
-        absmax[-1] = torch.abs(A_reshaped[n - rem:]).max()
-        scaled_A_rem = torch.clamp(A_reshaped[n - rem:] * (1 / absmax[-1]), -1, 1)
+        absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
+        scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
         scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
-    # map [-1, 1] to nf4
+    # map [-1, 1] to nf4/fp4
     out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8)
-    for i in range(len(NF4_QUANT_TABLE)):
-        out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
+    if quant_type == "nf4":
+        for i in range(len(NF4_QUANT_TABLE)):
+            out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
+    elif quant_type == "fp4":
+        sign = scaled_A < 0
+        abs_scaled_A = torch.abs(scaled_A)
+        for key, val in FP4_QUANT_TABLE.items():
+            out_uint8[abs_scaled_A > key] = val
+        out_uint8 += sign.to(torch.uint8) * 8
     if out_uint8.size(-1) % 2:
         out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0)
     out[:] = out_uint8[1::2].bitwise_left_shift(4).bitwise_or_(out_uint8[::2])
@@ -342,21 +361,21 @@
         quant_type=quant_type,
     )
 
-    if ipex_cpu and _ipex_cpu_version_prereq(2, 2) and input_shape[0] % blocksize == 0:
+    if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and input_shape[1] % blocksize == 0 and quant_type == "nf4":
         # lowp_mode: lowest precision for computation
         lowp_mode = ipex_cpu.quantization.WoqLowpMode.BF16
         state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
             out.reshape([input_shape[0], input_shape[1] // 2]),
             ipex_cpu.quantization.WoqWeightDtype.NF4,
-            input_shape, # weight shape
-            absmax.view(input_shape[0], input_shape[1] // blocksize), # scales
-            None, # zero_points
-            None, # bias
-            None, # g_idx
-            None, # batch_size
+            input_shape,  # weight shape
+            absmax.view(input_shape[0], input_shape[1] // blocksize),  # scales
+            None,  # zero_points
+            None,  # bias
+            None,  # g_idx
+            None,  # batch_size
             blocksize,
             int(lowp_mode),
-            -1, # act_quant_mode. -1 means don't quant activation
+            -1,  # act_quant_mode. -1 means don't quant activation
         )
 
     return out, state
@@ -365,7 +384,7 @@
 @_maybe_torch_compile
 def dequantize_4bit_impl(
     A: Tensor,
-    quant_state = None,
+    quant_state=None,
     absmax: Tensor = None,
     out: Tensor = None,
     blocksize: int = 64,
@@ -412,7 +431,7 @@
     else:
         absmax = quant_state.absmax
 
-    if quant_state.quant_type != "nf4":
+    if quant_type not in ["nf4", "fp4"]:
         raise NotImplementedError(
             f"4-bit quantization data type {quant_state.quant_type} is not implemented for CPU/XPU."
         )
@@ -421,9 +440,7 @@
         raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
 
     if out is None:
-        out = torch.empty(
-            quant_state.shape, dtype=quant_state.dtype, device=A.device
-        )
+        out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
 
     n = out.numel()
     # Map nf4 to [-1, 1]
@@ -443,9 +460,11 @@
     rem = n % blocksize
     has_rem = rem > 0
     out_reshaped = out.reshape(-1)
-    out_reshaped[:n - rem] = (out_dq[:n - rem].view(-1, blocksize) * absmax[:blocks - has_rem].view(-1, 1)).reshape(-1)
+    out_reshaped[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(
+        -1
+    )
     if has_rem:
-        out_reshaped[n - rem:] = out_dq[n - rem:] * absmax[-1]
+        out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1]
 
     # take transpose here because weight is transposed (again) for computation
     return out.t()
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 8e125f712..ea15f148a 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2003,7 +2003,8 @@ def test_bench_dequantization():
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
 @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
 @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
-def test_4bit_quant(dtype, quant_type, blocksize):
+@pytest.mark.parametrize("device", ["cuda", "cpu"])
+def test_4bit_quant(dtype, quant_type, blocksize, device):
     vals = list(product([0, 1], repeat=4))
 
     code = {}
@@ -2027,9 +2028,11 @@ def test_4bit_quant(dtype, quant_type, blocksize):
             result = sign * exp * frac
         code[idx] = result
 
-    A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
+    A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
     qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
     A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
+    if device == "cpu":
+        A2 = A2.t()
 
     err = (A1 - A2).abs().float()
     relerr = (err / (A1.abs().float() + 1e-8)).mean()
@@ -2279,6 +2282,49 @@ def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
         assert maxratio < 1.02 and maxratio > 0.98
 
 
+@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
+def test_gemv_4bit_cpu(dtype, quant_type, kind):
+    """
+    Test 4-bit GEMV on CPU. It is simplified from the CUDA version, since the CPU backend
+    does not support double_quant or quant_storage types other than uint8.
+    Also, the CPU backend has different numeric accuracy from that of CUDA.
+    """
+    for dim in [128, 256, 512, 1024]:
+        for i in range(10):
+            if kind == "fc1":
+                A = torch.randn(1, dim, dtype=dtype, device="cpu")
+                B = torch.randn(dim * 4, dim, dtype=dtype, device="cpu") / math.sqrt(dim)
+            elif kind == "fc2":
+                A = torch.randn(1, 4 * dim, dtype=dtype, device="cpu")
+                B = torch.randn(dim, 4 * dim, dtype=dtype, device="cpu") / math.sqrt(dim)
+            elif kind == "attn":
+                A = torch.randn(1, dim, dtype=dtype, device="cpu")
+                B = torch.randn(dim, dim, dtype=dtype, device="cpu") / math.sqrt(dim)
+            elif kind == "attn_packed":
+                A = torch.randn(1, dim, dtype=dtype, device="cpu")
+                B = torch.randn(dim * 3, dim, dtype=dtype, device="cpu") / math.sqrt(dim)
+
+            qB, state = F.quantize_4bit(
+                B,
+                quant_type=quant_type,
+                compress_statistics=False,
+                quant_storage=torch.uint8,
+            )
+            dqB = F.dequantize_4bit(qB, state)
+            C3 = torch.matmul(A, dqB)
+            C2 = F.gemv_4bit(A, qB.t(), state=state)
+            A.requires_grad = True
+            C1 = bnb.matmul_4bit(A, qB.t(), state)
+
+            c = int(C1.numel() * 0.0014 * (dim / 256)) + 1
+            rtol = 1e-3 if dtype != torch.bfloat16 else 1e-2
+            atol = 1e-2 if dtype != torch.bfloat16 else 5e-2
+            assert_all_approx_close(C1, C2, rtol, atol, count=c)
+            assert_all_approx_close(C3, C2, rtol, atol, count=c)
+
+
 @pytest.mark.skip("Row scale has some bugs for ampere")
 def test_managed():
     n = 32 * 10