Merge pull request bitsandbytes-foundation#898 from jianan-gu/upstream_device_abstraction

Initial working draft of a device abstraction to enable different hardware backends.
Titus-von-Koeller authored Apr 3, 2024
2 parents fd9d072 + 9ff6c63 commit 2ffa367
Showing 7 changed files with 903 additions and 675 deletions.
8 changes: 7 additions & 1 deletion bitsandbytes/__init__.py
@@ -12,9 +12,15 @@
     matmul_cublas,
     mm_cublas,
 )
 from .cextension import lib
 from .nn import modules
-from .optim import adam
+
+if lib and lib.compiled_with_cuda:
+    from .backends import register_backend
+    from .backends.cuda import CUDABackend
+    from .optim import adam
+
+    register_backend("cuda", CUDABackend())
 __pdoc__ = {
     "libbitsandbytes": False,
     "optim.optimizer.Optimizer8bit": False,
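With this change, importing bitsandbytes registers the CUDA backend only when the native library was compiled with CUDA support. The dispatch path itself is not shown in this hunk; as a rough sketch, assuming call sites look up the registered backend by the tensor's device type (quantize_on_device is a hypothetical helper, not part of this PR):

import torch

from bitsandbytes.backends import backends, ensure_backend_is_available


def quantize_on_device(A: torch.Tensor):
    # Hypothetical helper: route a 4-bit quantization call to whichever
    # backend is registered for A's device type (e.g. "cuda").
    device_type = A.device.type
    ensure_backend_is_available(device_type)  # raises NotImplementedError if nothing is registered
    return backends[device_type].quantize_4bit(A)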
15 changes: 15 additions & 0 deletions bitsandbytes/backends/__init__.py
@@ -0,0 +1,15 @@
from typing import Dict

from bitsandbytes.backends.base import Backend

backends: Dict[str, Backend] = {}


def register_backend(backend_name: str, backend_instance: Backend):
    backends[backend_name.lower()] = backend_instance


def ensure_backend_is_available(device_type: str):
"""Check if a backend is available for the given device type."""
if device_type.lower() not in backends:
raise NotImplementedError(f"Device backend for {device_type} is currently not supported.")
133 changes: 133 additions & 0 deletions bitsandbytes/backends/base.py
@@ -0,0 +1,133 @@
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch

from bitsandbytes.utils import QuantState


class Backend(ABC):
    """Base class for device backends that will implement their own 8-bit and 4-bit functions."""

    @abstractmethod
    def double_quant(
        self,
        A,
        col_stats=None,
        row_stats=None,
        out_col=None,
        out_row=None,
        threshold=0.0,
    ):
        raise NotImplementedError

    @abstractmethod
    def transform(
        self,
        A,
        to_order,
        from_order="row",
        out=None,
        transpose=False,
        state=None,
        ld=None,
    ):
        raise NotImplementedError

    @abstractmethod
    def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
        raise NotImplementedError

    @abstractmethod
    def mm_dequant(
        self,
        A,
        quant_state,
        row_stats,
        col_stats,
        out=None,
        new_row_stats=None,
        new_col_stats=None,
        bias=None,
    ):
        raise NotImplementedError

    @abstractmethod
    def extract_outliers(self, A, SA, idx):
        raise NotImplementedError

    @abstractmethod
    def quantize_4bit(
        self,
        A: torch.Tensor,
        absmax: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        blocksize=64,
        compress_statistics=False,
        quant_type="fp4",
        quant_storage=torch.uint8,
    ) -> Tuple[torch.Tensor, QuantState]:
        """
        Quantize tensor A in blocks of 4-bit values.

        Quantizes tensor A by dividing it into blocks which are independently quantized to FP4.

        Parameters
        ----------
        A : torch.Tensor
            The input tensor.
        absmax : torch.Tensor
            The absmax values.
        out : torch.Tensor
            The output tensor.
        blocksize : int
            The blocksize used in quantization.
        quant_type : str
            The 4-bit quantization data type {fp4, nf4}

        Returns
        -------
        torch.Tensor:
            Tensor with packed 4-bit values.
        tuple(torch.Tensor, torch.Size, torch.dtype, int):
            The quantization state to undo the quantization.
        """
        raise NotImplementedError

    @abstractmethod
    def dequantize_4bit(
        self,
        A: torch.Tensor,
        quant_state: Optional[QuantState] = None,
        absmax: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        blocksize: int = 64,
        quant_type="fp4",
    ) -> torch.Tensor:
        """
        Dequantizes FP4 blockwise quantized values.

        Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.

        Parameters
        ----------
        A : torch.Tensor
            The input tensor (packed 4-bit values).
        quant_state : QuantState
            object with quantisation stats, incl. absmax values, original tensor shape and original dtype.
        absmax : torch.Tensor
            The absmax values.
        out : torch.Tensor
            Dequantized output tensor.
        blocksize : int
            The blocksize used in quantization.
        quant_type : str
            The 4-bit quantization data type {fp4, nf4}

        Returns
        -------
        torch.Tensor:
            Dequantized tensor.
        """
        raise NotImplementedError
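A new hardware backend would subclass Backend, override every abstract method, and register an instance under its device-type string. The skeleton below is a hypothetical illustration; the "mydevice" name and the stubbed method bodies are placeholders, not part of this PR:

import torch

from bitsandbytes.backends import register_backend
from bitsandbytes.backends.base import Backend


class MyDeviceBackend(Backend):
    """Hypothetical backend skeleton: every abstract method must be overridden before instantiation."""

    def double_quant(self, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
        raise NotImplementedError("int8 double quantization is not implemented in this sketch")

    def transform(self, A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
        raise NotImplementedError

    def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
        raise NotImplementedError

    def mm_dequant(self, A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
        raise NotImplementedError

    def extract_outliers(self, A, SA, idx):
        raise NotImplementedError

    def quantize_4bit(self, A, absmax=None, out=None, blocksize=64, compress_statistics=False, quant_type="fp4", quant_storage=torch.uint8):
        raise NotImplementedError

    def dequantize_4bit(self, A, quant_state=None, absmax=None, out=None, blocksize=64, quant_type="fp4"):
        raise NotImplementedError


# Register the instance so lookups by the device-type key "mydevice" can find it.
register_backend("mydevice", MyDeviceBackend())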