Merge remote-tracking branch 'upstream/multi-backend-refactor' into multi-backend-refactor-cpu-xpu-ops
Xia-Weiwen committed May 7, 2024
2 parents b0dec0a + 749e06f commit 97e41b8
Showing 24 changed files with 1,580 additions and 585 deletions.
3 changes: 3 additions & 0 deletions .git-blame-ignore-revs
@@ -12,3 +12,6 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848

# Reformat with ruff-format
5a4263f4dc05fe8f78f4111beab9f68a81deeab1

# CHANGELOG: to reverse chron order + mdformat
4743ff0d43e04e4cc3e5d8b9e7cd016c0defa36d
4 changes: 3 additions & 1 deletion .github/workflows/python-package.yml
@@ -63,10 +63,12 @@ jobs:
os: [ubuntu-latest, windows-latest]
arch: [x86_64, aarch64]
cuda_version:
["11.7.1", "11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2"]
["11.7.1", "11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.0"]
exclude:
- os: windows-latest # This probably requires arm64 Windows agents
arch: aarch64
- os: windows-latest # The Jimver/cuda-toolkit action used for Windows builds is not updated for 12.4 yet.
cuda_version: "12.4.0"
- os: ubuntu-latest # Temporary. Takes too long, not ready yet.
arch: aarch64
runs-on: ${{ matrix.os }} # One day, we could run them on native agents. Azure supports this now but it's planned only for Q3 2023 for hosted agents
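
For illustration (not part of the diff), a short Python sketch of how the matrix entries and excludes above combine: the new "12.4.0" entry only produces ubuntu-latest/x86_64 jobs until the Windows CUDA toolkit action supports 12.4.

from itertools import product

oses = ["ubuntu-latest", "windows-latest"]
arches = ["x86_64", "aarch64"]
cuda_versions = ["11.7.1", "11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.0"]
excludes = [
    {"os": "windows-latest", "arch": "aarch64"},
    {"os": "windows-latest", "cuda_version": "12.4.0"},
    {"os": "ubuntu-latest", "arch": "aarch64"},
]

def is_excluded(combo):
    # A combination is dropped when it matches every key of any exclude entry.
    return any(all(combo[key] == value for key, value in rule.items()) for rule in excludes)

jobs = [
    combo
    for combo in ({"os": o, "arch": a, "cuda_version": c} for o, a, c in product(oses, arches, cuda_versions))
    if not is_excluded(combo)
]
print(len(jobs))  # 13 jobs: ubuntu/x86_64 for all seven CUDA versions, windows/x86_64 for all but 12.4.0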
511 changes: 291 additions & 220 deletions CHANGELOG.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions README.md
@@ -1,5 +1,7 @@
# `bitsandbytes`

[![Downloads](https://static.pepy.tech/badge/bitsandbytes)](https://pepy.tech/project/bitsandbytes) [![Downloads](https://static.pepy.tech/badge/bitsandbytes/month)](https://pepy.tech/project/bitsandbytes) [![Downloads](https://static.pepy.tech/badge/bitsandbytes/week)](https://pepy.tech/project/bitsandbytes)

The `bitsandbytes` library is a lightweight Python wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and 8 & 4-bit quantization functions.

The library includes quantization primitives for 8-bit & 4-bit operations, through `bitsandbytes.nn.Linear8bitLt` and `bitsandbytes.nn.Linear4bit` and 8-bit optimizers through `bitsandbytes.optim` module.
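
For illustration (not part of the diff), a minimal usage sketch of the entry points named above, assuming a CUDA-enabled bitsandbytes install:

import torch
import bitsandbytes as bnb

# 8-bit and 4-bit quantized linear layers (LLM.int8() and FP4/NF4).
linear8bit = bnb.nn.Linear8bitLt(4096, 4096, has_fp16_weights=False, threshold=6.0)
linear4bit = bnb.nn.Linear4bit(4096, 4096, quant_type="nf4")

# 8-bit optimizer from the bitsandbytes.optim module.
model = torch.nn.Linear(4096, 4096).cuda()
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)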
43 changes: 37 additions & 6 deletions bitsandbytes/__init__.py
@@ -19,17 +19,48 @@
from .cextension import lib
from .nn import modules

register_backend("cpu", CPUBackend)
# Always register the CPU backend.
register_backend("cpu", CPUBackend())

if lib and lib.compiled_with_cuda:
from .backends.cuda import CUDABackend
from .optim import adam
# Register either CUDA or ROCm backend, if available.
# Only one of these backends can be used at a time, since the torch.device semantics are
# the same for both torch+rocm and torch+cuda (e.g. device name is "cuda")
if torch.cuda.is_available():
# TODO: Consider deferring loading of cextension - should backend class implement that?

if torch.version.cuda:
from .backends.cuda import CUDABackend

register_backend("cuda", CUDABackend())
elif torch.version.hip:
from .backends.rocm import ROCmBackend

register_backend("cuda", ROCmBackend())

# Register MPS backend, if available.
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
from .backends.mps import MPSBackend

register_backend("mps", MPSBackend())

# Register Intel XPU backend, if available.
if hasattr(torch, "xpu") and torch.xpu.is_available():
from .backends.xpu import XPUBackend

register_backend("xpu", XPUBackend())

# TODO: Other potential backends:
# XLA - Google TPU / PJRT runtime
# HPU - Habana / Intel Gaudi
# IPU - Graphcore
# NPU - Ascend
# Note that we may not map 1:1 with a device type, e.g. SYCL, XLA
# In this case, it will be up to each backend to dispatch as needed

register_backend("cuda", CUDABackend())
__pdoc__ = {
"libbitsandbytes": False,
"optim.optimizer.Optimizer8bit": False,
"optim.optimizer.MockArgs": False,
}

__version__ = "0.44.0.dev"
__version__ = "0.43.2.dev"
185 changes: 162 additions & 23 deletions bitsandbytes/backends/base.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional, Tuple
from typing import Literal, Optional, Tuple, Union

import torch

@@ -12,48 +12,62 @@ class Backend(ABC):
@abstractmethod
def double_quant(
self,
A,
col_stats=None,
row_stats=None,
out_col=None,
out_row=None,
A: torch.Tensor,
col_stats: Optional[torch.Tensor] = None,
row_stats: Optional[torch.Tensor] = None,
out_col: Optional[torch.Tensor] = None,
out_row: Optional[torch.Tensor] = None,
threshold=0.0,
):
raise NotImplementedError

@abstractmethod
def transform(
self,
A,
to_order,
A: torch.Tensor,
to_order: str,
from_order="row",
out=None,
out: Optional[torch.Tensor] = None,
transpose=False,
state=None,
state: Optional[Tuple[torch.Size, str]] = None,
ld=None,
):
raise NotImplementedError

@abstractmethod
def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
def igemmlt(
self,
A: torch.Tensor,
B: torch.Tensor,
SA: Tuple[torch.Size, str],
SB: Tuple[torch.Size, str],
out: Optional[torch.Tensor] = None,
Sout: Optional[Tuple[torch.Size, str]] = None,
dtype=torch.int32,
) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]:
raise NotImplementedError

@abstractmethod
def mm_dequant(
self,
A,
quant_state,
row_stats,
col_stats,
out=None,
new_row_stats=None,
new_col_stats=None,
bias=None,
):
A: torch.Tensor,
quant_state: Tuple[torch.Size, str],
row_stats: torch.Tensor,
col_stats: torch.Tensor,
out: Optional[torch.Tensor] = None,
new_row_stats: Optional[torch.Tensor] = None,
new_col_stats: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError

@abstractmethod
def extract_outliers(self, A, SA, idx):
def extract_outliers(
self,
A: torch.Tensor,
SA: Tuple[torch.Size, str],
idx: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError

@abstractmethod
@@ -64,7 +78,7 @@ def quantize_4bit(
out: Optional[torch.Tensor] = None,
blocksize=64,
compress_statistics=False,
quant_type="fp4",
quant_type: Literal["fp4", "nf4"] = "fp4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
"""
@@ -102,7 +116,7 @@ def dequantize_4bit(
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: int = 64,
quant_type="fp4",
quant_type: Literal["fp4", "nf4"] = "fp4",
) -> torch.Tensor:
"""
Dequantizes FP4 blockwise quantized values.
@@ -131,3 +145,128 @@ def dequantize_4bit(
Dequantized tensor.
"""
raise NotImplementedError

@abstractmethod
def gemv_4bit(
self,
A: torch.Tensor,
B: torch.Tensor,
out: Optional[torch.Tensor] = None,
transposed_A=False,
transposed_B=False,
state: QuantState = None,
) -> torch.Tensor:
raise NotImplementedError

@abstractmethod
def quantize_blockwise(
self,
A: torch.Tensor,
code: Optional[torch.Tensor] = None,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize=4096,
nested=False,
) -> Tuple[torch.Tensor, QuantState]:
raise NotImplementedError

@abstractmethod
def dequantize_blockwise(
self,
A: torch.Tensor,
quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None,
code: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: int = 4096,
nested=False,
) -> torch.Tensor:
raise NotImplementedError

@abstractmethod
def optimizer_update_8bit_blockwise(
self,
optimizer_name: str,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
state2: Optional[torch.Tensor],
beta1: float,
beta2: float,
eps: float,
step: int,
lr: float,
qmap1: torch.Tensor,
qmap2: Optional[torch.Tensor],
absmax1: torch.Tensor,
absmax2: Optional[torch.Tensor],
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
skip_zeros=False,
) -> None:
"""
Performs an in-place optimizer update with one or two optimizer states.
Args:
optimizer_name (`str`): The name of the optimizer, e.g. `adam`
g (`torch.Tensor`): Gradient tensor.
p (`torch.Tensor`): Parameter tensor.
state1 (`torch.Tensor`): Optimizer state 1.
state2 (`torch.Tensor`, optional): Optimizer state 2.
beta1 (`float`): Optimizer beta1.
beta2 (`float`): Optimizer beta2.
eps (`float`): Optimizer epsilon.
step (`int`): Current optimizer step.
lr (`float`): The learning rate.
qmap1 (`torch.Tensor`): Quantization map for the first state.
qmap2 (`torch.Tensor`, optional): Quantization map for the second state.
absmax1 (`torch.Tensor`): Max value for the first state update.
absmax2 (`torch.Tensor`, optional): Max value for the second state update.
weight_decay (`float`, optional): Weight decay. Defaults to 0.0.
gnorm_scale (`float`, optional): The factor to rescale the gradient to the max clip value. Defaults to 1.0.
skip_zeros (`bool`, optional): Whether to skip zero-valued gradients or not. Defaults to False.
"""
raise NotImplementedError

@abstractmethod
def optimizer_update_32bit(
self,
optimizer_name: str,
g: torch.Tensor,
p: torch.Tensor,
state1: torch.Tensor,
beta1: float,
eps: float,
step: int,
lr: float,
state2: Optional[torch.Tensor] = None,
beta2: float = 0.0,
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
unorm_vec: Optional[torch.Tensor] = None,
max_unorm: float = 0.0,
skip_zeros=False,
) -> None:
"""
Performs an in-place optimizer update with one or two optimizer states.
Universal optimizer update for 32-bit state and 32/16-bit gradients/weights
Args:
optimizer_name (`str`): The name of the optimizer, e.g. `adam`
g (`torch.Tensor`): Gradient tensor.
p (`torch.Tensor`): Parameter tensor.
state1 (`torch.Tensor`): Optimizer state 1.
beta1 (`float`): Optimizer beta1.
eps (`float`): Optimizer epsilon.
step (`int`): Current optimizer step.
lr (`float`): The learning rate.
state2 (`torch.Tensor`, optional): Optimizer state 2. Defaults to None.
beta2 (`float`, optional): Optimizer beta2. Defaults to 0.0.
weight_decay (`float`, optional): Defaults to 0.0.
gnorm_scale (`float`, optional): The factor to rescale the gradient to the max clip value. Defaults to 1.0.
unorm_vec (`torch.Tensor`, optional): The tensor for the update norm. Defaults to None.
max_unorm (`float`, optional): The maximum update norm relative to the weight norm. Defaults to 0.0.
skip_zeros (`bool`, optional): Whether to skip zero-valued gradients or not. Defaults to False.
"""
raise NotImplementedError
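
For illustration (not part of the diff), a plain-PyTorch sketch of the textbook Adam step that the optimizer_update_32bit arguments describe; the real backends run fused kernels and details such as weight-decay handling may differ:

import torch

def adam_update_32bit_reference(g, p, state1, state2, beta1, beta2, eps, step, lr, gnorm_scale=1.0):
    # g: gradient, p: parameter, state1/state2: first/second moments; p and the states are updated in place.
    g = g * gnorm_scale                                 # rescale the gradient toward the max clip value
    state1.mul_(beta1).add_(g, alpha=1 - beta1)         # running average of gradients
    state2.mul_(beta2).addcmul_(g, g, value=1 - beta2)  # running average of squared gradients
    correction1 = 1 - beta1**step                       # bias corrections
    correction2 = 1 - beta2**step
    denom = (state2 / correction2).sqrt().add_(eps)
    p.addcdiv_(state1 / correction1, denom, value=-lr)  # in-place parameter update

p, g = torch.zeros(16), torch.randn(16)
m, v = torch.zeros(16), torch.zeros(16)
adam_update_32bit_reference(g, p, m, v, beta1=0.9, beta2=0.999, eps=1e-8, step=1, lr=1e-3)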
(Diffs for the remaining changed files are not rendered here.)