Enable dequant+matmul 8bit path for Intel CPU and XPU (#1484)
* new matmul8bit

Signed-off-by: jiqing-feng <[email protected]>

* fix cxb

Signed-off-by: jiqing-feng <[email protected]>

---------

Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng authored Jan 28, 2025
1 parent f6025bc commit 307fbd5
Showing 1 changed file with 25 additions and 0 deletions.
bitsandbytes/autograd/_functions.py
@@ -563,6 +563,29 @@ def backward(ctx, grad_output):
        return grad_A, grad_B, None, grad_bias, None


class MatMul8bitFp(torch.autograd.Function):
    # For Intel CPU and XPU, the double quant path has many unsafe operations which break finetuning.
    # For now, we use dequant + matmul to run finetuning instead.

    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
        CB = B.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)).t()
        output = torch.matmul(A, CB).to(A.dtype)
        ctx.state = state
        ctx.dtype_A = A.dtype
        ctx.grad_shape = A.shape
        return output

    @staticmethod
    def backward(ctx, grad_output):
        state = ctx.state
        B = state.CxB if state.CxB is not None else state.CB
        CB = B.to(ctx.dtype_A).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
        grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)

        return grad_A, None, None, None, None


def matmul(
    A: torch.Tensor,
    B: torch.Tensor,

@@ -574,6 +597,8 @@ def matmul(
    state = state or MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    if A.device.type in ("cpu", "xpu") and state.is_training:
        return MatMul8bitFp.apply(A, B, out, bias, state)
    return MatMul8bitLt.apply(A, B, out, bias, state)
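
Beyond the diff itself, here is a minimal, self-contained sketch of the dequant + matmul idea behind MatMul8bitFp. The names W, SCB, and B_int8 are illustrative stand-ins for the int8 weight and the per-row absmax scales that MatmulLtState carries; this sketches the technique, not the library API.

import torch

def dequant_matmul_sketch(A, B_int8, SCB):
    # Dequantize row-wise (each int8 row scaled by its absmax / 127), then
    # run a plain floating-point matmul, mirroring MatMul8bitFp.forward.
    CB = B_int8.to(A.dtype).mul_(SCB.unsqueeze(1).mul(1.0 / 127.0)).t()
    return torch.matmul(A, CB).to(A.dtype)

torch.manual_seed(0)
W = torch.randn(8, 16)                                # out_features x in_features
SCB = W.abs().amax(dim=1)                             # per-row absmax scales
B_int8 = torch.round(W * (127.0 / SCB.unsqueeze(1))).to(torch.int8)

A = torch.randn(4, 16)
out = dequant_matmul_sketch(A, B_int8, SCB)
ref = A @ W.t()
print(torch.allclose(out, ref, atol=0.1))             # True, up to quantization error

The point of the change: rather than going through the int8 double-quant kernels (whose unsafe operations break finetuning on these devices, per the comment in the diff), the weight is dequantized back to the activation dtype and an ordinary fp matmul is used.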


Expand Down
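A quick check, under the same illustrative setup, that the custom backward (grad_A = grad_output @ CB) matches what autograd computes for the dequantized weight:

import torch

torch.manual_seed(0)
W = torch.randn(8, 16)
SCB = W.abs().amax(dim=1)
B_int8 = torch.round(W * (127.0 / SCB.unsqueeze(1))).to(torch.int8)
W_dq = B_int8.float() * (SCB.unsqueeze(1) / 127.0)    # same dequant as forward/backward

A = torch.randn(4, 16, requires_grad=True)
(A @ W_dq.t()).sum().backward()                       # autograd reference

# With a sum() loss, grad_output is all ones, so grad_output @ W_dq
# reproduces MatMul8bitFp.backward's grad_A exactly.
grad_manual = torch.ones(4, 8) @ W_dq
print(torch.allclose(A.grad, grad_manual))            # True

Note that backward returns None for grad_B: the quantized weight receives no gradient in this path, so only the non-quantized parameters are updated during finetuning.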
