Add int8 ops for CPU #1178

16 changes: 12 additions & 4 deletions bitsandbytes/autograd/_functions.py
@@ -94,6 +94,9 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -
:param tile_indices: reverse transformation indices, from get_inverse_transform_indices
:return: contiguous row-major tensor
"""
# The layout is unchanged on CPU, so there is nothing to undo
if permuted_tensor.device.type == "cpu":
    return permuted_tensor
(rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
@@ -217,6 +220,8 @@ def backward(ctx, grad_output):

def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if device == torch.device("cpu"):
    return True
if torch.cuda.get_device_capability(device=device) < (7, 5):
    return False
device_name = torch.cuda.get_device_name(device=device)
@@ -312,13 +317,16 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
state.outlier_pool = GlobalOutlierPooler.get_instance()

# Cast A to fp16
if A.dtype != torch.float16:
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
A_dtype = torch.float16
Contributor:

Suggested change:
-A_dtype = torch.float16
+if A.dtype != torch.float16:
+    warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
+    A_dtype = torch.float16

Tensors which are already in fp16 do not need to be set again.

Author:

@abhilash1910 Thanks for the comment. Here we are considering other dtypes like bfloat16 for CPU.

A_dtype = torch.float16
if A.device == torch.device('cpu'):
    A_dtype = torch.bfloat16
if A.dtype != A_dtype:
    warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {A_dtype} during quantization")

Contributor:

Yes, correct, but if the tensor is already in fp16 then there is no need to convert, right? The condition only applies when the input is bf16 or another precision; only then does it enter that branch (the logic stays the same, I think). Let me know your thoughts. Looks OK either way.

Author:

The conversion is done afterwards; here we only print a warning. The actual cast happens in
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(A_dtype), threshold=state.threshold)
and in fact, if the tensor is already in A_dtype, no conversion takes place.
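
For reference, a minimal standalone sketch of that behavior (plain PyTorch, not part of this diff): Tensor.to() returns its input unchanged when the dtype already matches, so the cast inside F.double_quant(A.to(A_dtype), threshold=state.threshold) only does work when a conversion is actually required.

import torch

# Already in the requested dtype: .to() returns the very same tensor object,
# so no copy or conversion happens.
a_fp16 = torch.randn(4, 4, dtype=torch.float16)
assert a_fp16.to(torch.float16) is a_fp16

# Different dtype: .to() performs a real cast and returns a new tensor.
a_fp32 = torch.randn(4, 4)
a_cast = a_fp32.to(torch.float16)
assert a_cast.dtype == torch.float16
assert a_cast is not a_fp32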

Member:

Quick question that might be related here. Do we need to consider any changes (e.g. fall back to fp32) for users with a CPU that does not have AVX512-BF16 or AMX? Or is that something handled by torch?

Author:

It will fall back to fp32 automatically. It's handled by torch.
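
As a side note, a small standalone sketch of that fallback behavior (plain PyTorch, not part of this diff): bfloat16 kernels on CPU run on any hardware, and when AVX512-BF16/AMX are unavailable torch computes them through a slower software path, so no explicit fp32 fallback is needed in bitsandbytes. Recent PyTorch releases also expose torch.backends.cpu.get_cpu_capability() for inspecting the detected ISA level.

import torch

# bfloat16 matmul on CPU succeeds regardless of AVX512-BF16/AMX support;
# without native bf16 instructions the computation is emulated internally
# using fp32 arithmetic, transparently to the caller.
x = torch.randn(16, 32, dtype=torch.bfloat16)
w = torch.randn(32, 8, dtype=torch.bfloat16)
y = torch.matmul(x, w)
print(y.dtype, y.shape)  # torch.bfloat16 torch.Size([16, 8])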

if A.device == torch.device("cpu"):
    A_dtype = torch.bfloat16
if A.dtype != A_dtype:
    warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {A_dtype} during quantization")

# 1. Quantize A
if len(A.shape) == 3:
A = A.reshape(-1, A.shape[-1])
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(A_dtype), threshold=state.threshold)

if state.threshold > 0.0 and coo_tensorA is not None:
if state.has_fp16_weights:
@@ -393,7 +401,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
if using_igemmlt:
C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
if bias is None or bias.dtype == torch.float16:
if bias is None or bias.dtype == A_dtype:
# we apply the fused bias here
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
output = output.to(A.dtype)
82 changes: 70 additions & 12 deletions bitsandbytes/backends/cpu.py
@@ -5,9 +5,33 @@
from bitsandbytes.utils import QuantState

from .base import Backend
from .cpu_xpu_common import (
    double_quant_impl,
    igemmlt_impl,
    mm_dequant_impl,
)

Tensor = torch.Tensor


def assert_on_cpu(tensors):
    on_cpu = True
    for t in tensors:
        if t is None:
            continue  # NULL pointers are fine
        on_cpu &= t.device.type == "cpu"
    if not on_cpu:
        raise TypeError(
            "All input tensors need to be on CPU, but found some tensors to not be on CPU:\n"
            f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}"
        )
    return on_cpu


class CPUBackend(Backend):
mm_dequant_compute_dtype = torch.bfloat16
mm_dequant_output_dtype = torch.bfloat16

def double_quant(
self,
A: torch.Tensor,
@@ -17,7 +41,8 @@ def double_quant(
out_row: Optional[torch.Tensor] = None,
threshold=0.0,
):
raise NotImplementedError
assert_on_cpu([A, col_stats, row_stats, out_col, out_row])
return double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)

def transform(
self,
@@ -29,7 +54,23 @@ def transform(
state: Optional[Tuple[torch.Size, str]] = None,
ld=None,
):
raise NotImplementedError
"""
Transform tensor A to to_order. It is originally designed for CUDA.
For CPU, it returns the original tensor if transpose=False.
Otherwise, it returns the transpose of A
"""
assert_on_cpu([A, out])
if transpose:
if out is not None:
out.copy_(A.T)
else:
out = A.T
else:
if out is not None:
out.copy_(A)
else:
out = A
return out, state

def igemmlt(
self,
@@ -41,7 +82,8 @@ def igemmlt(
Sout: Optional[Tuple[torch.Size, str]] = None,
dtype=torch.int32,
) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]:
raise NotImplementedError
assert_on_cpu([A, B])
return igemmlt_impl(A, B, SA, SB, out, Sout, dtype)

def mm_dequant(
self,
@@ -54,15 +96,31 @@ def mm_dequant(
new_col_stats: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
assert_on_cpu([A, row_stats, col_stats, out, bias])
return mm_dequant_impl(
    A,
    quant_state,
    row_stats,
    col_stats,
    out,
    new_row_stats,
    new_col_stats,
    bias,
    self.mm_dequant_compute_dtype,
    self.mm_dequant_output_dtype,
)

def extract_outliers(
self,
A: torch.Tensor,
SA: Tuple[torch.Size, str],
idx: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
"""
Extract columns of A by idx
"""
assert_on_cpu([A])
return A[:, idx].contiguous()

def quantize_4bit(
self,
@@ -74,7 +132,7 @@ def quantize_4bit(
quant_type: Literal["fp4", "nf4"] = "fp4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
raise NotImplementedError
raise NotImplementedError("Not yet implemented for CPU backend")

def dequantize_4bit(
self,
@@ -85,7 +143,7 @@ def dequantize_4bit(
blocksize: int = 64,
quant_type: Literal["fp4", "nf4"] = "fp4",
) -> torch.Tensor:
raise NotImplementedError
raise NotImplementedError("Not yet implemented for CPU backend")

def gemv_4bit(
self,
@@ -96,7 +154,7 @@ def gemv_4bit(
transposed_B=False,
state: QuantState = None,
) -> torch.Tensor:
raise NotImplementedError
raise NotImplementedError("Not yet implemented for CPU backend")

def dequantize_blockwise(
self,
@@ -108,7 +166,7 @@ def dequantize_blockwise(
blocksize: int = 4096,
nested=False,
) -> torch.Tensor:
raise NotImplementedError
raise NotImplementedError("Not yet implemented for CPU backend")

def quantize_blockwise(
self,
@@ -119,7 +177,7 @@ def quantize_blockwise(
blocksize=4096,
nested=False,
) -> Tuple[torch.Tensor, QuantState]:
raise NotImplementedError
raise NotImplementedError("Not yet implemented for CPU backend")

def optimizer_update_8bit_blockwise(
self,
@@ -141,7 +199,7 @@ def optimizer_update_8bit_blockwise(
gnorm_scale: float = 1.0,
skip_zeros=False,
) -> None:
raise NotImplementedError
raise NotImplementedError("Not yet implemented for CPU backend")

def optimizer_update_32bit(
self,
@@ -161,4 +219,4 @@ def optimizer_update_32bit(
max_unorm: float = 0.0,
skip_zeros=False,
) -> None:
raise NotImplementedError
raise NotImplementedError("Not yet implemented for CPU backend")