Enable XPU and optimize cpu/xpu op (bitsandbytes-foundation#1418)
* enable new ipex API

ipex weight is 4D so we cannot transpose

fix dequant

check requires_grad

* use ipex op in backward

* enable backward

* Multi backend refactor (#8)

* AMD: Clarify diagnostic messages; free up disk space for CI build

* Add build job for rocm

* Add rocm build script

* Copy shared obj file into output_dir

* upload build artifacts and enable wheels build

* Remove cuda build temporarily

* Add ROCm version to .so filename

* Add rocm_version to whls build

* Revert "Remove cuda build temporarily"

This reverts commit 1413c5f.

* Add rocm_version env var

* Remove thrust header files

* Print node info

* print cuda node info

* Revert "print cuda node info"

This reverts commit cdb209a.

* Revert "Print node info"

This reverts commit 7e9a65c.

* Add rocm arch to compile command

* Rename .so files to rocm

* Update default gpu arch

* Skip cpu based igemmlt int tests on ROCm

* Update Documentation

* Update upstream repo name

* Update docs

* Update string format

Co-authored-by: Aarni Koskela <akx@iki.fi>

* Remove pre-release option for torch install

* Update pytorch install path

Co-authored-by: Titus <9048635+Titus-von-Koeller@users.noreply.github.com>

* Add messages for Heuristics error

* Remove toolcache for disk space

* print disk usage

* Clean disk space for linux

* Fix for ubuntu

* Add sudo for apt clean

* Update clean up disk list

* remove disk usage print

* Add BNB_BACKEND variable

* Update diagnostic functions for ROCm

* Fix tuple error

* Fix library detection bug for recursive and symlink cases

* fix pre-commit errors

* Remove recursive path lib search

* Create function for runtime lib patterns

* Update logger format

Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update error reporting

Co-authored-by: Aarni Koskela <akx@iki.fi>

* Remove commented code

Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update error reporting

Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update error reporting

* Create hip diagnostics functions

* Fix Typo

* Fix pre-commit checks

---------

Co-authored-by: Aarni Koskela <akx@iki.fi>
Co-authored-by: Titus <9048635+Titus-von-Koeller@users.noreply.github.com>

* check grad before using ipex (bitsandbytes-foundation#1358)

* Enable packaging for ROCm 6.2 (bitsandbytes-foundation#1367)

* Enable 6.2 build

* Update documentation for 6.2.0 pip install

* Update for VS2022 17.11 compatibility with CUDA < 12.4 (bitsandbytes-foundation#1341)

* Update for VS2022 17.11 compatibility with CUDA < 12.4

* Try again

* Enable continuous releases for multi-backend-refactor branch

* Update release workflow

* Publish continuous release for multi-backend

* continuous release: revert wheel renaming due to install err

* Revert "continuous release: revert wheel renaming due to install err"

This reverts commit 0a2b539.

* add dynamic tag-based versioning + git hash for dev vers

* docs: update w/ changes from `main`

* get tags for dynamic versioning

* fine-tune continuous release params

* reduce the pkg size + build times for the preview release

* refine docs for multi-backend alpha release (bitsandbytes-foundation#1380)

* refine docs for multi-backend alpha release

* docs: further tweaks to multi-backend alpha docs

* docs: further tweaks to multi-backend alpha docs

* docs: further tweaks to multi-backend alpha docs

* docs: add multi-backend feedback links

* docs: add request for contributions

* docs: small fixes

* docs: small fixes

* docs: add info about `main` continuous build

* docs: further tweaks to multi-backend alpha docs

* docs: further tweaks to multi-backend alpha docs

* docs: remove 2 obsolete lines

---------

Co-authored-by: pnunna93 <104791500+pnunna93@users.noreply.github.com>
Co-authored-by: Aarni Koskela <akx@iki.fi>
Co-authored-by: Titus <9048635+Titus-von-Koeller@users.noreply.github.com>
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Revert "enable backward"

This reverts commit cd7bf21.

* Revert "use ipex op in backward"

This reverts commit b8df1aa.

* fix finetune

* check training

* fix gemv check

* reformat

* avoid double quant in backward if not needed

* Zh/xpu support (#9)

* Add xpu support

* Add xpu support for int8

* Add xpu dequant kernel support

* update code

* remove debug comments

* remove redundant comments

* Add xpu integration for woqlinear

* correct the comments

* Update cpu_xpu_common.py

---------

Co-authored-by: zhuhong61 <hong.zhu@intel.com>
Co-authored-by: zhuhong61 <95205772+zhuhong61@users.noreply.github.com>

* avoid importing triton for the CPU and XPU backends

* fix setup in docker without git config

* xpu does not support torch.compile for now

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update xpu

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update 4bit compute dtype

* fix xpu int8 path

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* optimize 4bit dequant

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix xpu dequant

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* add empty cache in each xpu op

* add nf4 dequant ipex kernel

* fix dequant 4bit op

* empty cache has a negative effect on 4-bit gemv

* fix xpu save

* fix save

* xpu uses float16 by default

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* rm empty cache as it causes slower perf

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix xpu save

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix 8bit int8 param device

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix 8bit int8 param device

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix 8bit int8 param device

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix 8bit int8 param device

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

* update README: Intel CPU and XPU do not need to build the csrc code

* fix format

* fix import

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: pnunna93 <104791500+pnunna93@users.noreply.github.com>
Co-authored-by: Aarni Koskela <akx@iki.fi>
Co-authored-by: Titus <9048635+Titus-von-Koeller@users.noreply.github.com>
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Co-authored-by: zhuhong61 <hong.zhu@intel.com>
Co-authored-by: zhuhong61 <95205772+zhuhong61@users.noreply.github.com>
7 people authored Nov 29, 2024
1 parent cd73601 commit b2ac423
Showing 10 changed files with 246 additions and 101 deletions.
8 changes: 6 additions & 2 deletions bitsandbytes/__init__.py
@@ -17,11 +17,10 @@
matmul_cublas,
mm_cublas,
)
from .backends import register_backend
from .backends import backends, register_backend
from .backends.cpu import CPUBackend
from .backends.npu import NPUBackend
from .cextension import lib
from .nn import modules

features = {"multi_backend"}
supported_torch_devices = {
@@ -64,6 +63,11 @@
if hasattr(torch, "npu") and torch.npu.is_available():
register_backend("npu", NPUBackend())


# import module after decided backends
if backends:
from .nn import modules

# TODO: Other potential backends:
# XLA - Google TPU / PJRT runtime
# HPU - Habana / Intel Gaudi
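
With this guard in place, `bitsandbytes.nn` is only imported once a backend has actually been registered. A quick way to see what got registered at import time (a sketch assuming the multi-backend branch, where `bitsandbytes.backends.backends` is the device-type-to-backend registry shown in the hunk above):

```python
import torch

import bitsandbytes  # importing the package registers the available backends
from bitsandbytes.backends import backends

# Keys are device types such as "cpu", "xpu", "cuda", "npu".
print("registered backends:", sorted(backends.keys()))

# bitsandbytes.nn is only populated when at least one backend exists, so guard any use of it.
if "xpu" in backends and hasattr(torch, "xpu") and torch.xpu.is_available():
    print("XPU devices visible:", torch.xpu.device_count())
```
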
17 changes: 13 additions & 4 deletions bitsandbytes/autograd/_functions.py
@@ -221,7 +221,7 @@ def backward(ctx, grad_output):

def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if device == torch.device("cpu"):
if device == torch.device("cpu") or torch.device("xpu"):
return True
if torch.version.hip:
return False if BNB_HIP_VERSION < 601 else True
@@ -463,7 +463,9 @@ def backward(ctx, grad_output):
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = None, None, None, None, None
if req_gradB or (req_gradA and state.CBt):
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
@@ -575,8 +577,15 @@ def matmul_4bit(
bias=None,
):
assert quant_state is not None
if (A.numel() == A.shape[-1] or A.device.type == "cpu") and A.requires_grad == False:
# CPU backend does not require A to be a vector
if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
if getattr(quant_state, "ipex", False):
out = F.gemv_4bit(A, B.t(), out, state=quant_state)
if bias is not None:
out += bias
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)
elif A.numel() == A.shape[-1] and A.requires_grad == False:
if A.shape[-1] % quant_state.blocksize != 0:
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
70 changes: 42 additions & 28 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -15,6 +15,7 @@

ipex_cpu = ipex if ipex._C._has_cpu() else None
ipex_xpu = ipex if ipex._C._has_xpu() else None
ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu())
except BaseException:
ipex_cpu = None
ipex_xpu = None
@@ -55,7 +56,7 @@ def _ipex_xpu_version_prereq(major, minor):

def _maybe_torch_compile(func):
# torch.compile requires g++ and pytorch >= 2.0
if gxx_available and _torch_version_prereq(2, 0):
if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu:
options = {}
# fx_graph_cache requires pytorch >= 2.2
if _torch_version_prereq(2, 2):
@@ -181,7 +182,7 @@ def igemmlt_impl(A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32)
A_reshaped = A.reshape(m, k)

# torch._int_mm is available on CPU since torch 2.4
if _torch_version_prereq(2, 4):
if _torch_version_prereq(2, 4) and A.device.type == "cpu":
C = torch._int_mm(A_reshaped, B.T).to(dtype)
else:
C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype)
@@ -233,8 +234,10 @@ def mm_dequant_impl(
out_shape = (out_shape[0] * out_shape[1], out_shape[2])

if compute_dtype not in [torch.float32, torch.bfloat16]:
warnings.warn(f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use float instead")
compute_dtype = torch.float32
warnings.warn(
f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use bfloat16 instead"
)
compute_dtype = torch.bfloat16
A_reshaped = A.reshape(out_shape).to(compute_dtype)
row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype)
col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype)
@@ -342,7 +345,7 @@ def quantize_4bit_impl(
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
# map [-1, 1] to nf4/fp4
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8)
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device)
if quant_type == "nf4":
for i in range(len(NF4_QUANT_TABLE)):
out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
@@ -408,7 +411,6 @@ def dequantize_4bit_impl(
torch.Tensor:
Dequantized tensor.
"""

if A.shape[0] == 1:
transpose = False
A = A.squeeze(0)
@@ -438,23 +440,18 @@ def dequantize_4bit_impl(
if quant_state.nested:
raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")

if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(quant_state, "op_context"):
assert quant_state.op_context is not None
A = quant_state.op_context.to_public(quant_state.op_context.get_weight())
A = A.reshape(-1)
absmax = quant_state.op_context.get_scales().reshape(-1)

if out is None:
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
quant_state.ipex = False

n = out.numel()
# Map nf4 to [-1, 1]
out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device)
out_uint8[::2] = A.bitwise_and(0xF)
out_uint8[1::2] = A.bitwise_right_shift(4)
out_dq = torch.empty(out_uint8.shape).to(quant_state.dtype)
for i in range(len(quant_state.code)):
out_dq[out_uint8 == i] = quant_state.code[i]
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
n = out_dq.numel()
out_dq[::2] = A & 0xF
out_dq[1::2] = A >> 4
# quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue
quant_state.code = quant_state.code.to(quant_state.dtype)
out_dq = quant_state.code[out_dq]

# Apply scales
if out_dq.numel() != n:
@@ -464,12 +461,17 @@ def dequantize_4bit_impl(
blocks += 1 if n % blocksize > 0 else 0
rem = n % blocksize
has_rem = rem > 0
out_reshaped = out.reshape(-1)
out_reshaped[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(
-1
)

if has_rem:
if out is None:
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
out_reshaped = out.reshape(-1)
out_reshaped[: n - rem] = (
out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)
).reshape(-1)
out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1]
else:
out = (out_dq.view(-1, blocksize) * absmax.view(-1, 1)).reshape(quant_state.shape).to(quant_state.dtype)

# take transpose here because weight is transposed (again) for computation
if transpose:
@@ -510,9 +512,21 @@ def gemm_4bit_impl(
torch.Tensor:
GEMM output tensor.
"""
if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"):
assert state.op_context is not None
output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle())
if getattr(state, "ipex", False):
output = torch.ops.torch_ipex.woq_linear(
A,
B,
"nf4",
state.shape,
state.new_scales,
state.new_zeros,
None,
None,
state.blocksize,
ipex_cpu.quantization.WoqLowpMode.BF16,
1,
state.compensation,
)
else:
dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
output = torch.matmul(A, dqB.to(A.dtype))
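
The rewritten dequant path unpacks two 4-bit indices per byte (low nibble first), maps them through the NF4 code table, and then applies the per-block `absmax` scales. A standalone sketch of that scheme in plain PyTorch (`dequant_nf4` is a hypothetical helper, independent of ipex; the remainder-block handling of the real implementation is omitted and numel is assumed to be a multiple of blocksize):

```python
import torch

# Standard 16-entry NF4 code table (the same values bitsandbytes uses).
NF4_CODE = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
])


def dequant_nf4(packed, absmax, blocksize, shape, dtype=torch.float32):
    """Dequantize uint8-packed NF4 data; assumes numel is a multiple of blocksize."""
    flat = packed.reshape(-1)
    idx = torch.empty(flat.numel() * 2, dtype=torch.long, device=flat.device)
    idx[0::2] = (flat & 0xF).long()   # low nibble -> even positions
    idx[1::2] = (flat >> 4).long()    # high nibble -> odd positions
    out = NF4_CODE.to(device=flat.device, dtype=dtype)[idx]       # code-table lookup
    out = out.view(-1, blocksize) * absmax.view(-1, 1).to(dtype)  # per-block scaling
    return out.reshape(shape)
```
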
95 changes: 87 additions & 8 deletions bitsandbytes/backends/xpu.py
@@ -5,9 +5,36 @@
from bitsandbytes.utils import QuantState

from .base import Backend
from .cpu_xpu_common import (
dequantize_4bit_impl,
double_quant_impl,
gemm_4bit_impl,
igemmlt_impl,
mm_dequant_impl,
quantize_4bit_impl,
)

Tensor = torch.Tensor


def assert_on_xpu(tensors):
on_xpu = True
for t in tensors:
if t is None:
continue # NULL pointers are fine
on_xpu &= t.device.type == "xpu"
if not on_xpu:
raise TypeError(
"All input tensors need to be on XPU, but found some tensors to not be on XPU:\n"
f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}"
)
return on_xpu


class XPUBackend(Backend):
mm_dequant_compute_dtype = torch.bfloat16
mm_dequant_output_dtype = torch.bfloat16

def double_quant(
self,
A: torch.Tensor,
@@ -17,7 +44,9 @@ def double_quant(
out_row: Optional[torch.Tensor] = None,
threshold=0.0,
):
raise NotImplementedError
assert_on_xpu([A, col_stats, row_stats, out_col, out_row])
output = double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)
return output

def transform(
self,
@@ -29,7 +58,23 @@ def transform(
state: Optional[Tuple[torch.Size, str]] = None,
ld=None,
):
raise NotImplementedError
"""
Transform tensor A to to_order. It is originally designed for CUDA.
For XPU, it returns the original tensor if transpose=False.
Otherwise, it returns the transpose of A
"""
assert_on_xpu([A, out])
if transpose:
if out is not None:
out.copy_(A.T)
else:
out = A.T
else:
if out is not None:
out.copy_(A)
else:
out = A
return out, state

def igemmlt(
self,
@@ -41,7 +86,9 @@ def igemmlt(
Sout: Optional[Tuple[torch.Size, str]] = None,
dtype=torch.int32,
) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]:
raise NotImplementedError
assert_on_xpu([A, B])
output = igemmlt_impl(A, B, SA, SB, out, Sout, dtype)
return output

def mm_dequant(
self,
@@ -54,15 +101,30 @@ def mm_dequant(
new_col_stats: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
assert_on_xpu([A, row_stats, col_stats, out, bias])
output = mm_dequant_impl(
A,
quant_state,
row_stats,
col_stats,
out,
new_row_stats,
new_col_stats,
bias,
self.mm_dequant_compute_dtype,
self.mm_dequant_output_dtype,
)
return output

def extract_outliers(
self,
A: torch.Tensor,
SA: Tuple[torch.Size, str],
idx: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
assert_on_xpu([A])
output = A[:, idx].contiguous()
return output

def quantize_4bit(
self,
@@ -74,7 +136,12 @@ def quantize_4bit(
quant_type: Literal["fp4", "nf4"] = "fp4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
raise NotImplementedError
if blocksize is None:
blocksize = 64
assert_on_xpu([A, absmax, out])
assert quant_storage == torch.uint8, "XPU backend only supports uint8 quant_storage"
output = quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)
return output

def dequantize_4bit(
self,
@@ -85,7 +152,15 @@ def dequantize_4bit(
blocksize: int = 64,
quant_type: Literal["fp4", "nf4"] = "fp4",
) -> torch.Tensor:
raise NotImplementedError
if blocksize is None:
blocksize = 64
assert_on_xpu([A, absmax, out])
if quant_type == "nf4":
output = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t()
else:
output = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)

return output

def gemv_4bit(
self,
@@ -96,7 +171,11 @@ def gemv_4bit(
transposed_B=False,
state: QuantState = None,
) -> torch.Tensor:
raise NotImplementedError
assert_on_xpu([A, B, out])
if state is None:
raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
output = gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)
return output

def dequantize_blockwise(
self,
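
With these methods filled in, the XPU backend is reachable through the usual `bitsandbytes.functional` entry points. A minimal smoke test (a sketch; it assumes an Intel GPU with an ipex build that provides `torch.xpu`, and the tensor sizes are arbitrary):

```python
import torch
import bitsandbytes.functional as F

if hasattr(torch, "xpu") and torch.xpu.is_available():
    A = torch.randn(8, 128, dtype=torch.float16, device="xpu")

    # LLM.int8() building block: row-/col-wise int8 quantization now runs on XPU.
    CA, CAt, SCA, SCAt, coo = F.double_quant(A)
    print(CA.dtype, CA.device)    # torch.int8 on xpu

    # NF4 quantization on XPU (quant_storage must be uint8 on this backend).
    qA, state = F.quantize_4bit(A, blocksize=64, quant_type="nf4")
    print(qA.dtype, state.shape)  # torch.uint8, torch.Size([8, 128])
```
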
22 changes: 10 additions & 12 deletions bitsandbytes/functional.py
@@ -1006,11 +1006,6 @@ def dequantize_fp4(
out: Optional[torch.Tensor] = None,
blocksize: Optional[int] = None,
) -> Tensor:
if blocksize is None:
# Some AMD GPUs have warpsize 64
# Set default blocksize to 128 (~warpsize 64 in kernel) for HIP
blocksize = 64 if not HIP_ENVIRONMENT else 128

return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")


@@ -1021,11 +1016,6 @@ def dequantize_nf4(
out: Optional[torch.Tensor] = None,
blocksize: Optional[int] = None,
) -> Tensor:
if blocksize is None:
# Some AMD GPUs have warpsize 64
# Set default blocksize to 128 (~warpsize 64 in kernel) for HIP
blocksize = 64 if not HIP_ENVIRONMENT else 128

return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")


@@ -1035,7 +1025,7 @@ def dequantize_4bit(
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: Optional[int] = None,
quant_type="fp4",
quant_type=None,
) -> Tensor:
"""
Dequantizes FP4 blockwise quantized values.
@@ -1064,6 +1054,14 @@ def dequantize_4bit(
Dequantized tensor.
"""
ensure_backend_is_available(A.device.type)
if quant_state is not None:
absmax = absmax or quant_state.absmax
quant_type = quant_type or quant_state.quant_type
blocksize = blocksize or quant_state.blocksize
if blocksize is None:
# Some AMD GPUs have warpsize 64
# Set default blocksize to 128 (~warpsize 64 in kernel) for HIP
blocksize = 64 if not HIP_ENVIRONMENT else 128
return backends[A.device.type].dequantize_4bit(
A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type
)
@@ -1800,7 +1798,7 @@ class COOSparseTensor:
def __init__(self, rows, cols, nnz, rowidx, colidx, values):
assert rowidx.dtype == torch.int32
assert colidx.dtype == torch.int32
if values.device == torch.device("cpu"):
if values.device == torch.device("cpu") or torch.device("xpu"):
assert values.dtype in [torch.bfloat16, torch.half, torch.float]
else:
assert values.dtype == torch.float16
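
Since `dequantize_4bit` now pulls `absmax`, `quant_type`, and `blocksize` from the `QuantState` when they are not passed explicitly, callers can drop the redundant arguments. A small sketch (CPU tensors, assuming the multi-backend build; shapes are arbitrary):

```python
import torch
import bitsandbytes.functional as F

W = torch.randn(64, 128)
qW, state = F.quantize_4bit(W, blocksize=64, quant_type="nf4")

# These two calls are now equivalent; metadata missing from the call is read from `state`.
W_a = F.dequantize_4bit(qW, state)
W_b = F.dequantize_4bit(qW, state, blocksize=64, quant_type="nf4")
assert torch.equal(W_a, W_b)
```
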
16 changes: 10 additions & 6 deletions bitsandbytes/nn/__init__.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from ..backends import backends
from .modules import (
Embedding,
Int8Params,
@@ -14,9 +15,12 @@
StableEmbedding,
SwitchBackLinearBnb,
)
from .triton_based_modules import (
StandardLinear,
SwitchBackLinear,
SwitchBackLinearGlobal,
SwitchBackLinearVectorwise,
)

# CPU and XPU backend do not need triton, and XPU so not support triton for now.
if "xpu" not in backends.keys() and len(backends.keys()) > 1:
from .triton_based_modules import (
StandardLinear,
SwitchBackLinear,
SwitchBackLinearGlobal,
SwitchBackLinearVectorwise,
)
66 changes: 47 additions & 19 deletions bitsandbytes/nn/modules.py
100644 → 100755
@@ -314,6 +314,9 @@ def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: b
def cpu(self, non_blocking: bool = False):
return self.to(device="cpu", non_blocking=non_blocking)

def xpu(self, non_blocking: bool = False):
return self.to(device="xpu", non_blocking=non_blocking)

@overload
def to(
self: T,
@@ -331,7 +334,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

if device is not None and device.type in ["cuda", "cpu"] and not self.bnb_quantized:
if device is not None and device.type in ["cuda", "cpu", "xpu"] and not self.bnb_quantized:
return self._quantize(device)
else:
if self.quant_state is not None:
@@ -417,6 +420,7 @@ def __init__(
# self.persistent_buffers = [] # TODO consider as way to save quant state
self.compute_dtype = compute_dtype
self.compute_type_is_set = False
self.ipex_linear_is_set = False
self.quant_state = None
self.quant_storage = quant_storage

@@ -445,35 +449,39 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
save weight and bias,
then fill state_dict with components of quant_state
"""
if (
getattr(self.weight, "quant_state", None) is not None
and getattr(self.weight.quant_state, "op_context", None) is not None
):
context = self.weight.quant_state.op_context
self.weight.data = context.to_public(context.get_weight()).reshape([1, -1])
if getattr(self.weight, "quant_state", None) is not None and getattr(self.weight.quant_state, "ipex", False):
if self.weight.device.type == "cpu":
original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(
self.weight, "nf4", self.weight.quant_state.shape, 2
)
self.weight.data = original_weight.data
elif self.weight.device.type == "xpu":
self.weight.data = self.weight.data.reshape(1, -1)

self.weight.quant_state.ipex = False

super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias

if getattr(self.weight, "quant_state", None) is not None:
if (
self.weight.quant_state.absmax.shape.numel() == 0
and getattr(self.weight.quant_state, "op_context", None) is not None
):
self.weight.quant_state.absmax = context.get_scales().reshape(-1)
delattr(self.weight.quant_state, "op_context")
for k, v in self.weight.quant_state.as_dict(packed=True).items():
destination[prefix + "weight." + k] = v if keep_vars else v.detach()

def forward(self, x: torch.Tensor):
# Check if ipex fusion can be used
def set_ipex_linear(self, x: torch.Tensor):
if (
x.device.type == "cpu"
and not hasattr(self.weight.quant_state, "op_context")
(x.device.type in ("cpu", "xpu"))
and not getattr(self.weight.quant_state, "ipex", False)
and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
and self.weight.quant_state.quant_type == "nf4"
and not self.training
and x.requires_grad == False
):
enable_ipex_fusion(self.weight, self.weight.quant_state)
enable_ipex_fusion(self)

def forward(self, x: torch.Tensor):
# Check if ipex fusion can be used
if not self.ipex_linear_is_set:
self.set_ipex_linear(x)
self.ipex_linear_is_set = True

# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
@@ -633,7 +641,20 @@ def __deepcopy__(self, memo):

def cpu(self):
# we store the 8-bit rows-major weight
B = self.data.contiguous().bfloat16().cpu()
B = self.data.contiguous().to(torch.bfloat16).cpu()
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
if CBt is not None:
del CBt
if SCBt is not None:
del SCBt
self.data = CB
self.CB = CB
self.SCB = SCB
return self

def xpu(self):
# we store the 8-bit rows-major weight
B = self.data.contiguous().to(torch.float16).xpu()
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
if CBt is not None:
del CBt
@@ -669,6 +690,13 @@ def to(self, *args, **kwargs):
return self
else:
return self.cpu()
elif device.type == "xpu":
if self.data.dtype == torch.int8:
self.data = self.data.contiguous().xpu()
self.CB = self.data
return self
else:
return self.xpu()
else:
new_param = Int8Params(
super().to(device=device, dtype=dtype, non_blocking=non_blocking),
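
`Int8Params` now gains an `.xpu()` path that mirrors `.cpu()` (staging through float16 instead of bfloat16), and `.to()` routes `"xpu"` the same way it routes `"cpu"`. A usage sketch for the 8-bit module on a non-CUDA device (assumes the multi-backend build with ipex installed; falls back to CPU when no XPU is present, and the layer sizes are made up):

```python
import torch
import bitsandbytes as bnb

device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"
dtype = torch.float16 if device == "xpu" else torch.bfloat16

fp_linear = torch.nn.Linear(512, 256)
int8_linear = bnb.nn.Linear8bitLt(512, 256, has_fp16_weights=False)
int8_linear.load_state_dict(fp_linear.state_dict())
int8_linear = int8_linear.to(device)   # Int8Params.to() now quantizes on "cpu" and "xpu" alike

with torch.no_grad():
    y = int8_linear(torch.randn(4, 512, dtype=dtype, device=device))
print(y.shape)   # torch.Size([4, 256])
```
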
41 changes: 26 additions & 15 deletions bitsandbytes/utils.py
@@ -200,28 +200,39 @@ def unpack_tensor_to_dict(tensor_data):
return unpacked_dict


def enable_ipex_fusion(weight, quant_state):
from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq

if _ipex_cpu_version_prereq(2, 3):
import intel_extension_for_pytorch as ipex

lowp_mode = ipex.quantization.WoqLowpMode.BF16
quant_state.op_context = torch.ops.ipex_prepack.weight_only_qlinear_prepack(
weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
ipex.quantization.WoqWeightDtype.NF4,
def enable_ipex_fusion(linear):
from bitsandbytes.backends.cpu_xpu_common import (
_ipex_cpu_version_prereq,
_ipex_xpu_version_prereq,
ipex_cpu_only,
ipex_xpu,
)

if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5):
quant_state = linear.weight.quant_state
new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
"nf4",
quant_state.shape, # weight shape
quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
None, # zero_points
None, # bias
None, # g_idx
None, # batch_size
quant_state.blocksize,
int(lowp_mode),
-1, # act_quant_mode. -1 means don't quant activation
2,
)
quant_state.absmax = torch.Tensor()
weight.data = torch.empty([1, 0], dtype=torch.uint8)
elif ipex_xpu and _ipex_xpu_version_prereq(2, 5):
quant_state = linear.weight.quant_state
new_weight = linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2])

new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
new_zeros = None
compensation = None
linear.weight.data = new_weight.data
linear.weight.quant_state.ipex = True
linear.weight.quant_state.new_scales = new_scales
linear.weight.quant_state.new_zeros = new_zeros
linear.weight.quant_state.compensation = compensation


class QuantState:
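
After `enable_ipex_fusion` runs, the repacked weight's companions live as plain attributes on the `QuantState` (`ipex`, `new_scales`, `new_zeros`, `compensation`) instead of an opaque `op_context`, which is what `gemm_4bit_impl` above branches on. A condensed sketch of that consuming side (the `woq_linear` argument order is copied from the diff; `gemm_nf4` is a hypothetical name, `dequant_nf4` refers to the helper sketched earlier, and the whole thing is illustrative rather than the shipped code path):

```python
import torch


def gemm_nf4(A: torch.Tensor, packed_W: torch.Tensor, state) -> torch.Tensor:
    if getattr(state, "ipex", False):
        # Fast path: weight was repacked by enable_ipex_fusion for ipex's WOQ linear kernel.
        import intel_extension_for_pytorch as ipex
        return torch.ops.torch_ipex.woq_linear(
            A, packed_W, "nf4", state.shape,
            state.new_scales, state.new_zeros, None, None,
            state.blocksize, ipex.quantization.WoqLowpMode.BF16, 1, state.compensation,
        )
    # Fallback: dequantize to the original weight shape and do a plain matmul.
    W = dequant_nf4(packed_W, state.absmax, state.blocksize, state.shape, dtype=A.dtype)
    return A @ W.t()
```
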
6 changes: 2 additions & 4 deletions docs/source/installation.mdx
@@ -208,8 +208,8 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/YOUR_USERNAME/local/cuda-11.7
|-------------|------------------------|---------------------------|-------------------------|------------|
| **AMD ROCm** | 6.1+ | 3.10+ | minimum CDNA - `gfx90a`, RDNA - `gfx1100` | Alpha |
| **Apple Silicon (MPS)** | WIP | 3.10+ | M1/M2 chips | Planned |
| **Intel CPU** | v2.4.0+ (`ipex`) | 3.10+ | Intel CPU | Alpha |
| **Intel GPU** | v2.4.0+ (`ipex`) | 3.10+ | Intel GPU | Experimental |
| **Intel CPU** | v2.5.0+ (`ipex`) | 3.10+ | Intel CPU | Alpha |
| **Intel GPU** | v2.5.0+ (`ipex`) | 3.10+ | Intel GPU | Experimental |

For each supported backend, follow the respective instructions below:

@@ -336,8 +336,6 @@ The below commands are for Linux. For installing on Windows, please adapt the be
git clone --depth 1 -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/
pip install intel_extension_for_pytorch
pip install -r requirements-dev.txt
cmake -DCOMPUTE_BACKEND=cpu -S .
make
pip install -e . # `-e` for "editable" install, when developing BNB (otherwise leave that out)
```

6 changes: 3 additions & 3 deletions docs/source/non_cuda_backends.mdx
@@ -33,12 +33,12 @@ The following performance data is collected from Intel 4th Gen Xeon (SPR) platfo

| Data Type | BF16 | INT8 | NF4 | FP4 |
|---|---|---|---|---|
| Speed-Up (vs BF16) | 1.0x | 0.6x | 2.3x | 0.03x |
| Speed-Up (vs BF16) | 1.0x | 0.44x | 1.8x | 0.1x |
| Memory (GB) | 13.1 | 7.6 | 5.0 | 4.6 |

#### Fine-Tuning (CPU)

| Data Type | AMP BF16 | INT8 | NF4 | FP4 |
| Data Type | BF16 | INT8 | NF4 | FP4 |
|---|---|---|---|---|
| Speed-Up (vs AMP BF16) | 1.0x | 0.38x | 0.07x | 0.07x |
| Speed-Up (vs BF16) | 1.0x | 0.38x | 0.1x | 0.1x |
| Memory (GB) | 40 | 9 | 6.6 | 6.6 |
