Enable XPU and optimize cpu/xpu op #1418

Merged
Changes from all commits (39 commits)
012c660
enable new ipex API
jiqing-feng Sep 14, 2024
b8df1aa
use ipex op in backward
jiqing-feng Sep 23, 2024
cd7bf21
enable backward
jiqing-feng Sep 27, 2024
5e19019
Multi backend refactor (#8)
jiqing-feng Oct 15, 2024
dd3b745
Revert "enable backward"
jiqing-feng Oct 15, 2024
8422f63
Revert "use ipex op in backward"
jiqing-feng Oct 15, 2024
9cbc081
fix finetune
jiqing-feng Oct 21, 2024
6860a4a
check training
jiqing-feng Oct 21, 2024
b2233b7
fix gemv check
jiqing-feng Oct 22, 2024
dbafcbb
reformat
jiqing-feng Oct 22, 2024
702b748
avoid double quant in backward if not needed
jiqing-feng Nov 8, 2024
1bde567
Zh/xpu support (#9)
jiqing-feng Nov 12, 2024
a5d92c3
avoid import triton if CPU and XPU backend
jiqing-feng Nov 13, 2024
e7b755b
fix setup in docker without git config
jiqing-feng Nov 13, 2024
4d4e240
xpu do not support compile for now
jiqing-feng Nov 13, 2024
0ccb0b5
update xpu
jiqing-feng Nov 14, 2024
712f584
update 4bit compute dtype
jiqing-feng Nov 14, 2024
b58db74
fix xpu int8 path
jiqing-feng Nov 14, 2024
0c9015a
optimize 4bit dequant
jiqing-feng Nov 14, 2024
0e919dc
fix xpu dequant
jiqing-feng Nov 14, 2024
ee4fd00
add empty cache in each xpu op
jiqing-feng Nov 14, 2024
35b8c91
add nf4 dequant ipex kernel
jiqing-feng Nov 15, 2024
347524d
fix dequant 4bit op
jiqing-feng Nov 15, 2024
ed0e370
empty cache has negative effect on 4bit gemv
jiqing-feng Nov 15, 2024
11db860
fix xpu save
jiqing-feng Nov 15, 2024
92e8c87
fix save
jiqing-feng Nov 15, 2024
57e5492
rebase
jiqing-feng Nov 15, 2024
cf0a807
xpu use float16 default
jiqing-feng Nov 15, 2024
7a32842
rm empty cache as it cause slower perf
jiqing-feng Nov 15, 2024
e636f75
fix xpu save
jiqing-feng Nov 18, 2024
987423a
fix 8bit int8 param device
jiqing-feng Nov 18, 2024
9da03f1
fix 8bit int8 param device
jiqing-feng Nov 18, 2024
aa3b245
fix 8bit int8 param device
jiqing-feng Nov 18, 2024
1e27a22
fix 8bit int8 param device
jiqing-feng Nov 18, 2024
314f724
fix format
jiqing-feng Nov 19, 2024
95387c8
update readme for Intel CPU and XPU do not need make csrc codes
jiqing-feng Nov 20, 2024
f039cfe
fix format
jiqing-feng Nov 21, 2024
23281c8
Merge branch 'multi-backend-refactor' into new_ipex
jiqing-feng Nov 21, 2024
bde878a
fix import
jiqing-feng Nov 21, 2024
8 changes: 6 additions & 2 deletions bitsandbytes/__init__.py
@@ -17,11 +17,10 @@
matmul_cublas,
mm_cublas,
)
from .backends import register_backend
from .backends import backends, register_backend
from .backends.cpu import CPUBackend
from .backends.npu import NPUBackend
from .cextension import lib
from .nn import modules

features = {"multi_backend"}
supported_torch_devices = {
@@ -64,6 +63,11 @@
if hasattr(torch, "npu") and torch.npu.is_available():
register_backend("npu", NPUBackend())


# import modules only after the backends have been decided
if backends:
from .nn import modules

# TODO: Other potential backends:
# XLA - Google TPU / PJRT runtime
# HPU - Habana / Intel Gaudi
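For context, a minimal sketch (not part of the PR) of how a caller could check which backends were registered before relying on bitsandbytes.nn; the `backends` registry and `register_backend` follow the names in this diff, while the device keys checked are assumptions about what a given machine registers:

```python
import bitsandbytes as bnb

# `backends` is the registry that register_backend() fills at import time;
# per this diff, bitsandbytes.nn.modules is only imported once it is non-empty.
available = set(bnb.backends.keys())
print("registered bitsandbytes backends:", available)

if {"cpu", "xpu"} & available:
    from bitsandbytes.nn import Linear4bit  # safe: at least one backend exists
```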
17 changes: 13 additions & 4 deletions bitsandbytes/autograd/_functions.py
@@ -221,7 +221,7 @@ def backward(ctx, grad_output):

def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if device == torch.device("cpu"):
if device == torch.device("cpu") or torch.device("xpu"):
return True
if torch.version.hip:
return False if BNB_HIP_VERSION < 601 else True
@@ -463,7 +463,9 @@ def backward(ctx, grad_output):
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = None, None, None, None, None
if req_gradB or (req_gradA and state.CBt is not None):
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
@@ -575,8 +577,15 @@ def matmul_4bit(
bias=None,
):
assert quant_state is not None
if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False:
# CPU backend does not require A to be a vector
if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
if getattr(quant_state, "ipex", False):
out = F.gemv_4bit(A, B.t(), out, state=quant_state)
if bias is not None:
out += bias
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)
elif A.numel() == A.shape[-1] and A.requires_grad == False:
if A.shape[-1] % quant_state.blocksize != 0:
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
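A hedged usage sketch of the inference path that the matmul_4bit change above enables on CPU (the same flow applies to XPU). It assumes the multi-backend branch quantizes Params4bit when the module is moved to a registered backend device; layer sizes and dtypes are illustrative:

```python
import torch
import bitsandbytes as bnb

# NF4-quantized linear layer, inference only (requires_grad=False), so the
# new CPU/XPU branch of matmul_4bit is taken instead of the CUDA kernels.
fp_layer = torch.nn.Linear(512, 256, bias=False)
qlayer = bnb.nn.Linear4bit(512, 256, bias=False, quant_type="nf4", compute_dtype=torch.float32)
qlayer.weight = bnb.nn.Params4bit(fp_layer.weight.data, requires_grad=False, quant_type="nf4")
qlayer = qlayer.to("cpu")  # assumed to trigger 4-bit quantization on this branch

x = torch.randn(4, 512)
with torch.no_grad():
    y = qlayer(x)
print(y.shape)  # torch.Size([4, 256])
```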
70 changes: 42 additions & 28 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -15,6 +15,7 @@

ipex_cpu = ipex if ipex._C._has_cpu() else None
ipex_xpu = ipex if ipex._C._has_xpu() else None
ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu())
except BaseException:
ipex_cpu = None
ipex_xpu = None
@@ -55,7 +56,7 @@ def _ipex_xpu_version_prereq(major, minor):

def _maybe_torch_compile(func):
# torch.compile requires g++ and pytorch >= 2.0
if gxx_available and _torch_version_prereq(2, 0):
if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu:
options = {}
# fx_graph_cache requires pytorch >= 2.2
if _torch_version_prereq(2, 2):
@@ -181,7 +182,7 @@ def igemmlt_impl(A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32)
A_reshaped = A.reshape(m, k)

# torch._int_mm is available on CPU since torch 2.4
if _torch_version_prereq(2, 4):
if _torch_version_prereq(2, 4) and A.device.type == "cpu":
C = torch._int_mm(A_reshaped, B.T).to(dtype)
else:
C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype)
@@ -233,8 +234,10 @@ def mm_dequant_impl(
out_shape = (out_shape[0] * out_shape[1], out_shape[2])

if compute_dtype not in [torch.float32, torch.bfloat16]:
warnings.warn(f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use float instead")
compute_dtype = torch.float32
warnings.warn(
f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use bfloat16 instead"
)
compute_dtype = torch.bfloat16
A_reshaped = A.reshape(out_shape).to(compute_dtype)
row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype)
col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype)
@@ -342,7 +345,7 @@ def quantize_4bit_impl(
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
# map [-1, 1] to nf4/fp4
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8)
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device)
if quant_type == "nf4":
for i in range(len(NF4_QUANT_TABLE)):
out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
@@ -408,7 +411,6 @@ def dequantize_4bit_impl(
torch.Tensor:
Dequantized tensor.
"""

if A.shape[0] == 1:
transpose = False
A = A.squeeze(0)
@@ -438,23 +440,18 @@
if quant_state.nested:
raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")

if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(quant_state, "op_context"):
assert quant_state.op_context is not None
A = quant_state.op_context.to_public(quant_state.op_context.get_weight())
A = A.reshape(-1)
absmax = quant_state.op_context.get_scales().reshape(-1)

if out is None:
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
quant_state.ipex = False

n = out.numel()
# Map nf4 to [-1, 1]
out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device)
out_uint8[::2] = A.bitwise_and(0xF)
out_uint8[1::2] = A.bitwise_right_shift(4)
out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype)
for i in range(len(quant_state.code)):
out_dq[out_uint8 == i] = quant_state.code[i]
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
n = out_dq.numel()
out_dq[::2] = A & 0xF
out_dq[1::2] = A >> 4
# quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue
quant_state.code = quant_state.code.to(quant_state.dtype)
out_dq = quant_state.code[out_dq]

# Apply scales
if out_dq.numel() != n:
@@ -464,12 +461,17 @@
blocks += 1 if n % blocksize > 0 else 0
rem = n % blocksize
has_rem = rem > 0
out_reshaped = out.reshape(-1)
out_reshaped[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(
-1
)

if has_rem:
if out is None:
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
out_reshaped = out.reshape(-1)
out_reshaped[: n - rem] = (
out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)
).reshape(-1)
out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1]
else:
out = (out_dq.view(-1, blocksize) * absmax.view(-1, 1)).reshape(quant_state.shape).to(quant_state.dtype)

# take transpose here because weight is transposed (again) for computation
if transpose:
@@ -510,9 +512,21 @@ def gemm_4bit_impl(
torch.Tensor:
GEMM output tensor.
"""
if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"):
assert state.op_context is not None
output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle())
if getattr(state, "ipex", False):
output = torch.ops.torch_ipex.woq_linear(
A,
B,
"nf4",
state.shape,
state.new_scales,
state.new_zeros,
None,
None,
state.blocksize,
ipex_cpu.quantization.WoqLowpMode.BF16,
1,
state.compensation,
)
else:
dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
output = torch.matmul(A, dqB.to(A.dtype))
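To make the dequantize_4bit_impl rewrite above easier to follow (two 4-bit codes packed per uint8, low nibble first, a codebook gather, then per-block absmax scaling), here is a self-contained toy sketch. The 16-entry linear codebook is a stand-in for the real NF4 table and none of this is the bitsandbytes API:

```python
import torch

blocksize = 4
code = torch.linspace(-1.0, 1.0, 16)      # toy codebook (stand-in for NF4)
values = torch.randn(8)                   # two blocks of 4 values

# quantize: per-block absmax scaling, then nearest-codebook-entry index
absmax = values.view(-1, blocksize).abs().max(dim=1).values
scaled = (values.view(-1, blocksize) / absmax.unsqueeze(1)).reshape(-1)
idx = (scaled.unsqueeze(1) - code.unsqueeze(0)).abs().argmin(dim=1).to(torch.uint8)

# pack two 4-bit indices per byte: even positions in the low nibble, odd in the high
packed = (idx[::2] & 0xF) | (idx[1::2] << 4)

# dequantize, mirroring the diff: unpack nibbles, gather from the codebook, rescale
unpacked = torch.empty(packed.numel() * 2, dtype=torch.int64)
unpacked[::2] = packed & 0xF
unpacked[1::2] = packed >> 4
restored = (code[unpacked].view(-1, blocksize) * absmax.unsqueeze(1)).reshape(-1)
print("max reconstruction error:", (values - restored).abs().max().item())
```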
95 changes: 87 additions & 8 deletions bitsandbytes/backends/xpu.py
@@ -5,9 +5,36 @@
from bitsandbytes.utils import QuantState

from .base import Backend
from .cpu_xpu_common import (
dequantize_4bit_impl,
double_quant_impl,
gemm_4bit_impl,
igemmlt_impl,
mm_dequant_impl,
quantize_4bit_impl,
)

Tensor = torch.Tensor


def assert_on_xpu(tensors):
on_xpu = True
for t in tensors:
if t is None:
continue # NULL pointers are fine
on_xpu &= t.device.type == "xpu"
if not on_xpu:
raise TypeError(
"All input tensors need to be on XPU, but found some tensors to not be on XPU:\n"
f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}"
)
return on_xpu


class XPUBackend(Backend):
mm_dequant_compute_dtype = torch.bfloat16
mm_dequant_output_dtype = torch.bfloat16

def double_quant(
self,
A: torch.Tensor,
@@ -17,7 +44,9 @@ def double_quant(
out_row: Optional[torch.Tensor] = None,
threshold=0.0,
):
raise NotImplementedError
assert_on_xpu([A, col_stats, row_stats, out_col, out_row])
output = double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)
return output

def transform(
self,
@@ -29,7 +58,23 @@ def transform(
state: Optional[Tuple[torch.Size, str]] = None,
ld=None,
):
raise NotImplementedError
"""
Transform tensor A to to_order. This op was originally designed for CUDA;
for XPU it returns A itself (copied into `out` if given) when transpose=False,
and the transpose of A otherwise.
"""
assert_on_xpu([A, out])
if transpose:
if out is not None:
out.copy_(A.T)
else:
out = A.T
else:
if out is not None:
out.copy_(A)
else:
out = A
return out, state

def igemmlt(
self,
@@ -41,7 +86,9 @@ def igemmlt(
Sout: Optional[Tuple[torch.Size, str]] = None,
dtype=torch.int32,
) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]:
raise NotImplementedError
assert_on_xpu([A, B])
output = igemmlt_impl(A, B, SA, SB, out, Sout, dtype)
return output

def mm_dequant(
self,
@@ -54,15 +101,30 @@ def mm_dequant(
new_col_stats: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
assert_on_xpu([A, row_stats, col_stats, out, bias])
output = mm_dequant_impl(
A,
quant_state,
row_stats,
col_stats,
out,
new_row_stats,
new_col_stats,
bias,
self.mm_dequant_compute_dtype,
self.mm_dequant_output_dtype,
)
return output

def extract_outliers(
self,
A: torch.Tensor,
SA: Tuple[torch.Size, str],
idx: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
assert_on_xpu([A])
output = A[:, idx].contiguous()
return output

def quantize_4bit(
self,
@@ -74,7 +136,12 @@ def quantize_4bit(
quant_type: Literal["fp4", "nf4"] = "fp4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
raise NotImplementedError
if blocksize is None:
blocksize = 64
assert_on_xpu([A, absmax, out])
assert quant_storage == torch.uint8, "XPU backend only supports uint8 quant_storage"
output = quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)
return output

def dequantize_4bit(
self,
@@ -85,7 +152,15 @@ def dequantize_4bit(
blocksize: int = 64,
quant_type: Literal["fp4", "nf4"] = "fp4",
) -> torch.Tensor:
raise NotImplementedError
if blocksize is None:
blocksize = 64
assert_on_xpu([A, absmax, out])
if quant_type == "nf4":
output = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t()
else:
output = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)

return output

def gemv_4bit(
self,
@@ -96,7 +171,11 @@ def gemv_4bit(
transposed_B=False,
state: QuantState = None,
) -> torch.Tensor:
raise NotImplementedError
assert_on_xpu([A, B, out])
if state is None:
raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
output = gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)
return output

def dequantize_blockwise(
self,
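Finally, a hedged end-to-end sketch of what this backend enables for 8-bit inference on an Intel GPU (requires intel_extension_for_pytorch and an available torch.xpu device; module names follow bitsandbytes, while the rest, including whether .to("xpu") triggers quantization on this branch, is an assumption):

```python
import torch
import bitsandbytes as bnb

assert hasattr(torch, "xpu") and torch.xpu.is_available(), "needs an Intel XPU + IPEX"

# Int8 linear layer; its forward routes through XPUBackend.double_quant,
# igemmlt and mm_dequant defined above.
fp16_layer = torch.nn.Linear(1024, 1024, bias=False).half()
int8_layer = bnb.nn.Linear8bitLt(1024, 1024, bias=False, has_fp16_weights=False)
int8_layer.weight = bnb.nn.Int8Params(fp16_layer.weight.data, requires_grad=False, has_fp16_weights=False)
int8_layer = int8_layer.to("xpu")  # assumed to quantize the weight to int8 on this branch

x = torch.randn(2, 1024, dtype=torch.float16, device="xpu")
with torch.no_grad():
    y = int8_layer(x)
print(y.shape, y.dtype)
```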