
Add int8 ops for CPU #1178

7 changes: 6 additions & 1 deletion bitsandbytes/__init__.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from . import research, utils
from .autograd._functions import (
    MatmulLtState,
@@ -14,13 +15,17 @@
)
from .cextension import lib
from .nn import modules
from .backends import register_backend

from .backends.cpu import CPUBackend
register_backend("cpu", CPUBackend)

if lib and lib.compiled_with_cuda:
    from .backends import register_backend
    from .backends.cuda import CUDABackend
    from .optim import adam

    register_backend("cuda", CUDABackend())

__pdoc__ = {
    "libbitsandbytes": False,
    "optim.optimizer.Optimizer8bit": False,
13 changes: 9 additions & 4 deletions bitsandbytes/autograd/_functions.py
@@ -217,6 +217,8 @@ def backward(ctx, grad_output):

def supports_igemmlt(device: torch.device) -> bool:
    """check if this device supports the optimized int8 kernel"""
    if device == torch.device('cpu'):
        return True
    if torch.cuda.get_device_capability(device=device) < (7, 5):
        return False
    device_name = torch.cuda.get_device_name(device=device)
@@ -312,13 +314,16 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
            state.outlier_pool = GlobalOutlierPooler.get_instance()

        # Cast A to fp16
        if A.dtype != torch.float16:
            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
        A_dtype = torch.float16
Contributor:

Suggested change: replace

A_dtype = torch.float16

with

if A.dtype != torch.float16:
    warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
    A_dtype = torch.float16

Tensors which are already in fp16 do not need to be set again.

Author:

@abhilash1910 Thanks for the comment. Here we are considering other dtypes like bfloat16 for CPU.

A_dtype = torch.float16
if A.device == torch.device('cpu'):
    A_dtype = torch.bfloat16
if A.dtype != A_dtype:
    warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {A_dtype} during quantization")

Contributor:

Yes, correct, but if the tensor is already in fp16 then there is no need to convert, right? The condition only applies when bf16 or another precision is used; then it goes into the condition (the logic stays the same, I think). Let me know your thoughts. Looks OK either way.

Author:

The conversion is done afterwards; this check only prints a warning. The actual cast happens in:

CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(A_dtype), threshold=state.threshold)

In fact, if the tensor is already in A_dtype, no conversion is performed.

Member:

Quick question that might be related here: do we need to consider any changes (e.g. falling back to fp32) for users with a CPU that does not have AVX512-BF16 or AMX? Or is that handled by torch?

Author:

It will fall back to fp32 automatically; that is handled by torch.
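A quick local check of this behavior, outside the PR (a sketch only; it assumes a recent PyTorch 2.x where torch.backends.cpu.get_cpu_capability() is available):

import torch

# bf16 matmul on CPU runs on any machine; without AVX512-BF16/AMX kernels,
# torch falls back to an fp32-based path internally, so only speed differs.
x = torch.randn(8, 64, dtype=torch.bfloat16)
w = torch.randn(16, 64, dtype=torch.bfloat16)
y = x @ w.T

# Reports the detected SIMD level, e.g. "AVX2" or "AVX512"
# (it does not report BF16/AMX support separately).
print(torch.backends.cpu.get_cpu_capability())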

        if A.device == torch.device('cpu'):
            A_dtype = torch.bfloat16
        if A.dtype != A_dtype:
            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {A_dtype} during quantization")

        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.reshape(-1, A.shape[-1])
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(A_dtype), threshold=state.threshold)

        if state.threshold > 0.0 and coo_tensorA is not None:
            if state.has_fp16_weights:
@@ -393,7 +398,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
        if using_igemmlt:
            C32A, SA = F.transform(CA, "col32")
            out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
            if bias is None or bias.dtype == torch.float16:
            if bias is None or bias.dtype == A_dtype:
                # we apply the fused bias here
                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
                output = output.to(A.dtype)
117 changes: 117 additions & 0 deletions bitsandbytes/backends/cpu.py
@@ -0,0 +1,117 @@
import torch
from .cpu_xpu_common import (
    double_quant_impl,
    igemmlt_impl,
    mm_dequant_impl,
)


Tensor = torch.Tensor


def assert_on_cpu(tensors):
    on_cpu = True
    for t in tensors:
        if t is None: continue  # NULL pointers are fine
        on_cpu &= (t.device.type == 'cpu')
    if not on_cpu:
        raise TypeError(
            'All input tensors need to be on CPU, but found some tensors to not be on CPU:\n'
            f' {[(t.shape, t.device) if isinstance(t, Tensor) else None for t in tensors]}'
        )
    return on_cpu


class CPUBackend:
    mm_dequant_compute_dtype = torch.bfloat16
    mm_dequant_output_dtype = torch.bfloat16

    @classmethod
    def double_quant(
        cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
    ):
        assert_on_cpu([A, col_stats, row_stats, out_col, out_row])
        return double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)

    @classmethod
    def transform(cls, A, to_order=None, from_order='row', out=None, transpose=False, state=None, ld=None):
        """
        Transform tensor A to to_order. It is originally designed for CUDA.
        For CPU, it returns the original tensor if transpose=False.
        Otherwise, it returns the transpose of A.
        """
        assert_on_cpu([A, out])
        if transpose:
            if out is not None:
                out.copy_(A.T)
            else:
                out = A.T
        else:
            if out is not None:
                out.copy_(A)
            else:
                out = A
        return out, state

    @classmethod
    def igemmlt(cls, A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32):
        assert_on_cpu([A, B])
        return igemmlt_impl(A, B, SA, SB, out, Sout, dtype)

    @classmethod
    def mm_dequant(
        cls,
        A,
        quant_state,
        row_stats,
        col_stats,
        out=None,
        new_row_stats=None,
        new_col_stats=None,
        bias=None
    ):
        assert_on_cpu([A, row_stats, col_stats, out, bias])
        return mm_dequant_impl(
            A,
            quant_state,
            row_stats,
            col_stats,
            out,
            new_row_stats,
            new_col_stats,
            bias,
            cls.mm_dequant_compute_dtype,
            cls.mm_dequant_output_dtype
        )

    @classmethod
    def extract_outliers(cls, A, SA, idx):
        """
        Extract columns of A by idx
        """
        assert_on_cpu([A])
        return A[:, idx].contiguous()

    @classmethod
    def quantize_4bit(
        cls,
        A: Tensor,
        absmax: Tensor = None,
        out: Tensor = None,
        blocksize=64,
        compress_statistics=False,
        quant_type="fp4",
    ) -> Tensor:
        assert False, "quantize_4bit not yet implemented for CPU backend"

    @classmethod
    def dequantize_4bit(
        cls,
        A: Tensor,
        quant_state=None,
        absmax: Tensor = None,
        out: Tensor = None,
        blocksize: int = 64,
        quant_type="fp4",
    ) -> Tensor:
        assert False, "dequantize_4bit not yet implemented for CPU backend"
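For orientation, a minimal sketch of how these backend pieces compose into an int8 matmul on CPU (illustration only; the tensor names and shapes are made up, and it assumes a build that includes this PR):

import torch
from bitsandbytes.backends.cpu import CPUBackend

x = torch.randn(32, 64, dtype=torch.bfloat16)   # activations [m, k]
w = torch.randn(48, 64, dtype=torch.bfloat16)   # weights [n, k]

# Row-wise symmetric int8 quantization of activations and weights
cx, _, scx, _, _ = CPUBackend.double_quant(x)
cw, _, scw, _, _ = CPUBackend.double_quant(w)

# int8 x int8 -> int32 GEMM, then dequantize with the two sets of row scales
out32, sout = CPUBackend.igemmlt(cx, cw)
y = CPUBackend.mm_dequant(out32, sout, scx, scw)  # bf16 output, approximately x @ w.T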
206 changes: 206 additions & 0 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -0,0 +1,206 @@
import torch
import warnings


Tensor = torch.Tensor


def _torch_version_prereq(major, minor):
    ver_major = int(torch.__version__.split('.')[0])
    ver_minor = int(torch.__version__.split('.')[1])
    return ver_major * 32 + ver_minor >= major * 32 + minor


def _maybe_torch_compile(func):
    # torch.compile requires pytorch >= 2.0
    if _torch_version_prereq(2, 0):
        options = {}
        # fx_graph_cache requires pytorch >= 2.2
        if _torch_version_prereq(2, 2):
            options.update({"fx_graph_cache": True})
        return torch.compile(func, dynamic=True, options=options)
    return func


@_maybe_torch_compile
def double_quant_impl(
    A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
):
    """
    Find the absolute max values of each row/column of a tensor, and symmetrically quantize it to int8.
    If threshold > 0.0, only values whose absolute value is below the threshold are counted. All outliers
    are zeroed out in the original tensor and kept in COO format: (rows, cols, values).
    If threshold == 0.0, there are no outliers.
    Args:
        A           The tensor to be analyzed and quantized.
        col_stats   Absolute max values of each column of A. If it is not None, use the values directly.
                    Otherwise, find the values.
        row_stats   Absolute max values of each row of A. If it is not None, use the values directly.
                    Otherwise, find the values.
        out_col     Output buffer for the result quantized per column if it is not None.
        out_row     Output buffer for the result quantized per row if it is not None.
        threshold   The threshold for finding outliers if it is > 0.0. Otherwise it has no effect.
    Return:
        A tuple of output quantized per row, output quantized per column, absolute max values of
        each row of A, absolute max values of each column of A, outliers in COO format.
    """
    from ..functional import COOSparseTensor
    cols = A.shape[-1]
    if len(A.shape) == 3:
        rows = A.shape[0] * A.shape[1]
    else:
        assert A.dim() == 2, f"double_quant: Input tensor should be 2d or 3d but got {A.dim()}d"
        rows = A.shape[0]
    A = A.reshape(rows, cols)

    coo_tensor = None

    def get_row_col_stats(A):
        row_stats = torch.max(torch.abs(A), 1).values  # absolute max of each row
        col_stats = torch.max(torch.abs(A), 0).values  # absolute max of each col
        return row_stats, col_stats

    def quant_to_int8(A, stats):
        return torch.clamp(torch.round(A * (127.0 / stats)), -128, 127).to(torch.int8)

    if threshold == 0.0:
        if row_stats is None or col_stats is None:
            row_stats, col_stats = get_row_col_stats(A)
    else:
        outlier_indices = torch.abs(A) >= threshold  # find outliers
        outlier_coord = outlier_indices.nonzero()  # get outlier coordinates
        outlier_rows = outlier_coord[:, 0]  # outlier row for COO sparse tensor
        outlier_cols = outlier_coord[:, 1]  # outlier column for COO sparse tensor
        outlier_values = A[outlier_indices]  # outlier values for COO sparse tensor
        coo_tensor = COOSparseTensor(
            A.shape[0], A.shape[1], outlier_values.numel(), outlier_rows.int(), outlier_cols.int(), outlier_values
        )
        if row_stats is None or col_stats is None:
            A[outlier_indices] = 0  # zero out outliers
            row_stats, col_stats = get_row_col_stats(A)

    quant_by_row = quant_to_int8(A, row_stats.unsqueeze(-1))
    quant_by_col = quant_to_int8(A, col_stats.unsqueeze(0))

    if coo_tensor is not None:
        A[outlier_indices] = outlier_values  # restore outliers for later use

    if out_row is not None:
        out_row.copy_(quant_by_row)
    else:
        out_row = quant_by_row
    if out_col is not None:
        out_col.copy_(quant_by_col)
    else:
        out_col = quant_by_col
    # Return float stats to align with the CUDA impl
    return out_row, out_col, row_stats.float(), col_stats.float(), coo_tensor


def igemmlt_impl(
    A, B, SA=None, SB=None, out=None, Sout=None, dtype=torch.int32
):
    """
    Do GEMM computation. Data type: int8 * int8 -> int32.
    Args:
        A      Activation of linear, data type is int8
        B      Weight of linear, data type is int8
        SA     Not used for CPU/XPU
        SB     Not used for CPU/XPU
        out    Specified output tensor if it is not None
        Sout   Not used for CPU/XPU but returned as is
        dtype  Data type of output
    Return:
        A tuple of the GEMM result in dtype and Sout
    """
    assert A.dtype == torch.int8
    assert B.dtype == torch.int8
    if out is not None:
        assert out.dtype == dtype

    dimsA = A.ndim
    dimsB = B.ndim
    shapeA = A.shape
    shapeB = B.shape
    assert dimsA in [2, 3], 'Only two or three dimensional matrices are supported for argument A'
    assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'

    if dimsA == 2:
        m = shapeA[0]
    elif dimsA == 3:
        m = shapeA[0] * shapeA[1]
    n = shapeB[0]
    k = shapeA[-1]
    assert shapeA[-1] == shapeB[-1], f'Shapes of A and B do not match, got {shapeA} and {shapeB}'

    # if the tensor is empty, return a transformed empty tensor with the right dimensions
    if shapeA[0] == 0 and dimsA == 2:
        return torch.empty((0, n), device=A.device, dtype=A.dtype)
    elif shapeA[1] == 0 and dimsA == 3:
        return torch.empty(tuple(shapeA[:2] + [n]), device=A.device, dtype=A.dtype)

    A_reshaped = A.reshape(m, k)

    # torch._int_mm is available on CPU since torch 2.4
    if _torch_version_prereq(2, 4):
        C = torch._int_mm(A_reshaped, B.T).to(dtype)
    else:
        C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype)
    if C.ndim != dimsA:
        assert dimsA == 3
        shapeOut = (shapeA[0], m // shapeA[0], C.shape[-1])
        C = C.reshape(shapeOut)
    if out is not None:
        out.copy_(C)
    else:
        out = C

    return out, Sout


@_maybe_torch_compile
def mm_dequant_impl(
    A,
    quant_state,
    row_stats,
    col_stats,
    out=None,
    new_row_stats=None,
    new_col_stats=None,
    bias=None,
    compute_dtype=torch.float32,
    output_dtype=torch.float32
):
    """
    Dequantize and add bias:
    out = A_int32 * (abs_max_A * abs_max_B) / (127 * 127) + bias
    Args:
        A              The output of the int8 GEMM, whose dtype is int32
        quant_state    Not used for CPU
        row_stats      Absolute max value of each row of the input (A) of the GEMM
        col_stats      Absolute max value of each row of the weight (B) of the GEMM
        out            Output buffer
        new_row_stats  Not used for CPU/XPU
        new_col_stats  Not used for CPU/XPU
        bias           Bias of the linear layer
        compute_dtype  Data type for computation
        output_dtype   Data type for output
    Return:
        The result
    """
    assert A.dtype == torch.int32
    out_shape = A.shape
    if len(out_shape) == 3:
        out_shape = (out_shape[0] * out_shape[1], out_shape[2])

    if compute_dtype not in [torch.float32, torch.bfloat16]:
        warnings.warn(f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use float instead")
        compute_dtype = torch.float32
    A_reshaped = A.reshape(out_shape).to(compute_dtype)
    row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype)
    col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype)
    out = A_reshaped * row_stats * col_stats / (127 * 127)
    if bias is not None:
        out = out + bias.to(compute_dtype)
    out = out.to(output_dtype)
    return out
return out
Loading