Commit

Merge pull request #1178 from Xia-Weiwen/multi-backend-refactor-cpu-xpu-ops

Add int8 ops for CPU
Titus-von-Koeller authored May 7, 2024
2 parents 749e06f + 37b0582 commit 8561f09
Showing 6 changed files with 423 additions and 36 deletions.
16 changes: 12 additions & 4 deletions bitsandbytes/autograd/_functions.py
@@ -94,6 +94,9 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -
:param tile_indices: reverse transformation indices, from get_inverse_transform_indices
:return: contiguous row-major tensor
"""
# On CPU the layout is not changed, so return the tensor as-is
if permuted_tensor.device.type == "cpu":
return permuted_tensor
(rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
@@ -217,6 +220,8 @@ def backward(ctx, grad_output):

def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if device == torch.device("cpu"):
return True
if torch.cuda.get_device_capability(device=device) < (7, 5):
return False
device_name = torch.cuda.get_device_name(device=device)
@@ -312,13 +317,16 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
state.outlier_pool = GlobalOutlierPooler.get_instance()

# Cast A to fp16
if A.dtype != torch.float16:
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
A_dtype = torch.float16
if A.device == torch.device("cpu"):
A_dtype = torch.bfloat16
if A.dtype != A_dtype:
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {A_dtype} during quantization")

# 1. Quantize A
if len(A.shape) == 3:
A = A.reshape(-1, A.shape[-1])
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(A_dtype), threshold=state.threshold)

if state.threshold > 0.0 and coo_tensorA is not None:
if state.has_fp16_weights:
@@ -393,7 +401,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
if using_igemmlt:
C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
if bias is None or bias.dtype == torch.float16:
if bias is None or bias.dtype == A_dtype:
# we apply the fused bias here
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
output = output.to(A.dtype)
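For context, the forward() change above simply selects the quantization dtype per device: bfloat16 on CPU (where the new int8 ops operate in bf16) and float16 everywhere else. Below is a minimal standalone sketch of that selection; the helper name _quant_dtype_for is hypothetical and not part of this PR.

import warnings

import torch


def _quant_dtype_for(device: torch.device) -> torch.dtype:
    # Hypothetical helper mirroring the hunk above: the CPU path works in
    # bfloat16, while the CUDA path keeps casting inputs to float16.
    return torch.bfloat16 if device.type == "cpu" else torch.float16


A = torch.randn(4, 8, dtype=torch.float32)
A_dtype = _quant_dtype_for(A.device)
if A.dtype != A_dtype:
    warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {A_dtype} during quantization")
A = A.to(A_dtype)  # cast happens just before double_quant, as in the diff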
82 changes: 70 additions & 12 deletions bitsandbytes/backends/cpu.py
@@ -5,9 +5,33 @@
from bitsandbytes.utils import QuantState

from .base import Backend
from .cpu_xpu_common import (
double_quant_impl,
igemmlt_impl,
mm_dequant_impl,
)

Tensor = torch.Tensor


def assert_on_cpu(tensors):
on_cpu = True
for t in tensors:
if t is None:
continue # NULL pointers are fine
on_cpu &= t.device.type == "cpu"
if not on_cpu:
raise TypeError(
"All input tensors need to be on CPU, but found some tensors to not be on CPU:\n"
f" {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}"
)
return on_cpu


class CPUBackend(Backend):
mm_dequant_compute_dtype = torch.bfloat16
mm_dequant_output_dtype = torch.bfloat16

def double_quant(
self,
A: torch.Tensor,
@@ -17,7 +41,8 @@ def double_quant(
out_row: Optional[torch.Tensor] = None,
threshold=0.0,
):
raise NotImplementedError
assert_on_cpu([A, col_stats, row_stats, out_col, out_row])
return double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)

def transform(
self,
@@ -29,7 +54,23 @@ def transform(
state: Optional[Tuple[torch.Size, str]] = None,
ld=None,
):
raise NotImplementedError
"""
Transform tensor A to to_order. It is originally designed for CUDA.
For CPU, it returns the original tensor if transpose=False.
Otherwise, it returns the transpose of A
"""
assert_on_cpu([A, out])
if transpose:
if out is not None:
out.copy_(A.T)
else:
out = A.T
else:
if out is not None:
out.copy_(A)
else:
out = A
return out, state

def igemmlt(
self,
@@ -41,7 +82,8 @@ def igemmlt(
Sout: Optional[Tuple[torch.Size, str]] = None,
dtype=torch.int32,
) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]:
raise NotImplementedError
assert_on_cpu([A, B])
return igemmlt_impl(A, B, SA, SB, out, Sout, dtype)

def mm_dequant(
self,
@@ -54,15 +96,31 @@ def mm_dequant(
new_col_stats: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
assert_on_cpu([A, row_stats, col_stats, out, bias])
return mm_dequant_impl(
A,
quant_state,
row_stats,
col_stats,
out,
new_row_stats,
new_col_stats,
bias,
self.mm_dequant_compute_dtype,
self.mm_dequant_output_dtype,
)

def extract_outliers(
self,
A: torch.Tensor,
SA: Tuple[torch.Size, str],
idx: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError
"""
Extract columns of A by idx
"""
assert_on_cpu([A])
return A[:, idx].contiguous()

def quantize_4bit(
self,
@@ -74,7 +132,7 @@ def quantize_4bit(
quant_type: Literal["fp4", "nf4"] = "fp4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
raise NotImplementedError
raise NotImplementedError("Not yet implemented for CPU backend")

def dequantize_4bit(
self,
@@ -85,7 +143,7 @@ def dequantize_4bit(
blocksize: int = 64,
quant_type: Literal["fp4", "nf4"] = "fp4",
) -> torch.Tensor:
raise NotImplementedError
raise NotImplementedError("Not yet implemented for CPU backend")

def gemv_4bit(
self,
@@ -96,7 +154,7 @@ def gemv_4bit(
transposed_B=False,
state: QuantState = None,
) -> torch.Tensor:
raise NotImplementedError
raise NotImplementedError("Not yet implemented for CPU backend")

def dequantize_blockwise(
self,
@@ -108,7 +166,7 @@ def dequantize_blockwise(
blocksize: int = 4096,
nested=False,
) -> torch.Tensor:
raise NotImplementedError
raise NotImplementedError("Not yet implemented for CPU backend")

def quantize_blockwise(
self,
@@ -119,7 +177,7 @@ def quantize_blockwise(
blocksize=4096,
nested=False,
) -> Tuple[torch.Tensor, QuantState]:
raise NotImplementedError
raise NotImplementedError("Not yet implemented for CPU backend")

def optimizer_update_8bit_blockwise(
self,
@@ -141,7 +199,7 @@ def optimizer_update_8bit_blockwise(
gnorm_scale: float = 1.0,
skip_zeros=False,
) -> None:
raise NotImplementedError
raise NotImplementedError("Not yet implemented for CPU backend")

def optimizer_update_32bit(
self,
@@ -161,4 +219,4 @@ def optimizer_update_32bit(
max_unorm: float = 0.0,
skip_zeros=False,
) -> None:
raise NotImplementedError
raise NotImplementedError("Not yet implemented for CPU backend")
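To see how the newly implemented CPU ops compose, here is a rough usage sketch of the int8 path: quantize both operands with double_quant, run the int8 GEMM with igemmlt, then dequantize with mm_dequant. It mirrors the call pattern in MatMul8bitLt.forward; the exact semantics of igemmlt_impl (e.g. that it computes A @ B.T with B stored as [out_features, in_features], as on the CUDA path) are an assumption here, so treat this as illustrative rather than authoritative.

import torch

from bitsandbytes.backends.cpu import CPUBackend

backend = CPUBackend()

A = torch.randn(16, 32, dtype=torch.bfloat16)  # activations
B = torch.randn(64, 32, dtype=torch.bfloat16)  # weight, assumed [out_features, in_features]

# 1. Int8 quantization of both operands (threshold=0.0: no outlier handling)
CA, CAt, SCA, SCAt, _ = backend.double_quant(A, threshold=0.0)
CB, CBt, SCB, SCBt, _ = backend.double_quant(B, threshold=0.0)

# 2. Int8 matmul. On CPU, transform() is the identity (or a plain transpose),
#    so the quantized tensors are passed through without layout changes.
out32, Sout32 = backend.igemmlt(CA, CB)

# 3. Dequantize using the row/column statistics; the result comes back in
#    bfloat16, matching mm_dequant_output_dtype above.
output = backend.mm_dequant(out32, Sout32, SCA, SCB)

The five return values of double_quant follow the call site in MatMul8bitLt.forward (CA, CAt, SCA, SCAt, coo_tensorA); only CA/CB and SCA/SCB are needed for this simple path.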