diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 6440ab1b5..9de5a8924 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -563,6 +563,29 @@ def backward(ctx, grad_output):
         return grad_A, grad_B, None, grad_bias, None
 
 
+class MatMul8bitFp(torch.autograd.Function):
+    # For Intel CPU and XPU, double quantization involves several unsafe operations that break fine-tuning.
+    # For now we use dequant + matmul to run fine-tuning instead.
+
+    @staticmethod
+    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
+        CB = B.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)).t()
+        output = torch.matmul(A, CB).to(A.dtype)
+        ctx.state = state
+        ctx.dtype_A = A.dtype
+        ctx.grad_shape = A.shape
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        state = ctx.state
+        B = state.CxB if state.CxB is not None else state.CB
+        CB = B.to(ctx.dtype_A).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
+        grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+
+        return grad_A, None, None, None, None
+
+
 def matmul(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -574,6 +597,8 @@ def matmul(
     state = state or MatmulLtState()
     if threshold > 0.0:
         state.threshold = threshold
+    if A.device.type in ("cpu", "xpu") and state.is_training:
+        return MatMul8bitFp.apply(A, B, out, bias, state)
     return MatMul8bitLt.apply(A, B, out, bias, state)
 
 
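
Reviewer note (not part of the patch): below is a minimal, self-contained sketch of the dequant + matmul idea that MatMul8bitFp.forward performs. The helper names (quantize_rowwise, dequant_matmul, row_scales) are illustrative placeholders, not bitsandbytes API; row_scales plays the role of state.SCB (per-row absmax scales) and W_int8 the role of the stored int8 weight.

import torch

def quantize_rowwise(W: torch.Tensor):
    # Per-row absmax quantization to int8, analogous to how CB / SCB relate in MatmulLtState.
    scales = W.abs().max(dim=1).values
    W_int8 = torch.clamp((W / scales.unsqueeze(1) * 127.0).round(), -127, 127).to(torch.int8)
    return W_int8, scales

def dequant_matmul(A: torch.Tensor, W_int8: torch.Tensor, row_scales: torch.Tensor) -> torch.Tensor:
    # Recover an approximate fp weight, mirroring
    # B.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) in the patch.
    CB = W_int8.to(A.dtype) * (row_scales.unsqueeze(1) / 127.0)
    # A plain fp matmul then stands in for the int8 Lt kernel on CPU/XPU during training.
    return A @ CB.t()

# Example usage with random data:
A = torch.randn(4, 8)                     # activations
W = torch.randn(16, 8)                    # fp weight, shape (out_features, in_features)
W_int8, scales = quantize_rowwise(W)
out = dequant_matmul(A, W_int8, scales)   # close to A @ W.t()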