diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 698c21481..241f90fca 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -101,6 +101,7 @@ jobs: name: shared_library_cuda_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.cuda_version }} path: output/* retention-days: 7 + build-wheels: needs: - build-shared-libs @@ -121,7 +122,7 @@ jobs: runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 - - name: Download build artifact + - name: Download build artifacts uses: actions/download-artifact@v4 with: merge-multiple: true @@ -152,6 +153,54 @@ jobs: path: dist/bitsandbytes-*.whl retention-days: 7 + upload-pre-release-wheels: + name: Create release and upload artifacts + runs-on: ubuntu-latest + permissions: + contents: write + needs: + - build-wheels + steps: + - name: Download artifacts to tmp directory + uses: actions/download-artifact@v4 + with: + path: tmp/ + pattern: "bdist_wheel_*" + merge-multiple: true + - name: Inspect tmp directory after downloading artifacts + run: ls -alFR tmp/ + - name: Move and rename wheel files + run: | + mkdir -p wheels/ + find tmp/ -type f -name '*.whl' -print0 | while IFS= read -r -d '' wheel; do + wheel_filename=$(basename "$wheel") + if [[ $wheel_filename == *linux*x86_64* ]]; then + mv "$wheel" wheels/bnb-linux-x86_64.whl + elif [[ $wheel_filename == *linux*aarch64* ]]; then + mv "$wheel" wheels/bnb-linux-aarch64.whl + elif [[ $wheel_filename == *macosx*x86_64* ]]; then + mv "$wheel" wheels/bnb-macos-x86_64.whl + elif [[ $wheel_filename == *macosx*arm64* ]]; then + mv "$wheel" wheels/bnb-macos-arm64.whl + elif [[ $wheel_filename == *win*amd64* ]]; then + mv "$wheel" wheels/bnb-windows-x86_64.whl + else + echo "Unknown wheel format: $wheel_filename" + exit 1 + fi + done + - name: Inspect wheels directory after renaming files + run: ls -alFR wheels/ + - name: Create release and upload artifacts + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITHUB_CONTINUOUS_RELEASE_TYPE: prerelease + GITHUB_CONTINUOUS_RELEASE_TAG: continuous-release_main + run: | + wget -q https://github.com/TheAssassin/pyuploadtool/releases/download/continuous/pyuploadtool-x86_64.AppImage + chmod +x pyuploadtool-x86_64.AppImage + ./pyuploadtool-x86_64.AppImage --appimage-extract-and-run wheels/*.whl + audit-wheels: needs: build-wheels runs-on: ubuntu-latest diff --git a/.github/workflows/upload_pr_documentation.yml b/.github/workflows/upload_pr_documentation.yml index 6497caf2d..707705297 100644 --- a/.github/workflows/upload_pr_documentation.yml +++ b/.github/workflows/upload_pr_documentation.yml @@ -6,6 +6,10 @@ on: types: - completed +permissions: + contents: read + pull-requests: write # Allows posting comments on pull requests + jobs: build: uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main diff --git a/CHANGELOG.md b/CHANGELOG.md index ed324f09e..fb69ff376 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,12 @@ +### 0.43.3 + +#### Improvements: + +- FSDP: Enable loading prequantized weights with bf16/fp16/fp32 quant_storage + - Background: This update, linked to [Transformer PR #32276](https://github.com/huggingface/transformers/pull/32276), allows loading prequantized weights with alternative storage formats. Metadata is tracked similarly to `Params4bit.__new__` post PR #970. 
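For context on the prequantized-loading improvement described in this changelog entry, here is a minimal sketch (not part of this patch) of how such a checkpoint is typically consumed through Transformers. It assumes a `transformers` release that includes PR #32276; the repo id is the model referenced in this entry and is used purely as an illustration.

```python
# Hedged sketch: loading a prequantized NF4 checkpoint whose 4-bit weights are
# stored in bf16 containers (quant_storage=bfloat16). Illustrative only; the
# 405B model referenced above obviously needs multi-GPU hardware.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16",
    torch_dtype=torch.bfloat16,  # matches the checkpoint's quant_storage dtype
    device_map="auto",           # requires `accelerate`
)
```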
It supports models exported with non-default `quant_storage`, such as [this NF4 model with BF16 storage](https://huggingface.co/hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16). + - Special thanks to @winglian and @matthewdouglas for enabling FSDP+QLoRA finetuning of Llama 3.1 405B on a single 8xH100 or 8xA100 node with as little as 256GB system RAM. + + ### 0.43.2 This release is quite significant as the QLoRA bug fix big implications for higher `seqlen` and batch sizes. diff --git a/CMakeLists.txt b/CMakeLists.txt index 6f3914456..d305e5a3e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,6 +69,11 @@ if(BUILD_CUDA) # This needs to be added *before* we try to enable the CUDA language so CMake's compiler check passes. if(MSVC AND MSVC_VERSION VERSION_GREATER_EQUAL 1940) string(APPEND CMAKE_CUDA_FLAGS " --allow-unsupported-compiler") + + # This is needed to build with VS2022 17.11+ and CUDA < 12.4. + if (MSVC_VERSION VERSION_GREATER_EQUAL 1941) + string(APPEND CMAKE_CUDA_FLAGS " -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH") + endif() endif() enable_language(CUDA) # This will fail if CUDA is not found diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index a8acfbfc5..78c99355b 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -21,4 +21,4 @@ "optim.optimizer.MockArgs": False, } -__version__ = "0.43.3.dev" +__version__ = "0.44.0.dev" diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index c8ae7358d..45573538e 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -20,6 +20,7 @@ import logging import os from pathlib import Path +import re import torch @@ -44,13 +45,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: override_value = os.environ.get("BNB_CUDA_VERSION") if override_value: - library_name_stem, _, library_name_ext = library_name.rpartition(".") - # `library_name_stem` will now be e.g. `libbitsandbytes_cuda118`; - # let's remove any trailing numbers: - library_name_stem = library_name_stem.rstrip("0123456789") - # `library_name_stem` will now be e.g. `libbitsandbytes_cuda`; - # let's tack the new version number and the original extension back on. - library_name = f"{library_name_stem}{override_value}.{library_name_ext}" + library_name = re.sub("cuda\d+", f"cuda{override_value}", library_name, count=1) logger.warning( f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n" diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index cea3179a1..4b9b02506 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -439,6 +439,11 @@ def is_on_gpu(tensors): return on_gpu +def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream: + stream = torch.cuda.current_stream(tensor.device) + return stream + + def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: """ Get the ctypes pointer from a PyTorch Tensor. @@ -973,6 +978,7 @@ def dequantize_blockwise( f"The blockwise of {quant_state.blocksize} is not supported. 
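The `get_tensor_stream()` helper added above returns the tensor's current CUDA stream, which is then handed directly to the ctypes bindings. A small sketch of why that works (assuming current PyTorch behavior; not part of this patch):

```python
# Hedged sketch: torch.cuda.Stream exposes `_as_parameter_` (a ctypes.c_void_p
# wrapping the raw cudaStream_t), so ctypes forwards the stream returned by
# get_tensor_stream() as the kernel-launch stream argument.
import ctypes
import torch

if torch.cuda.is_available():
    x = torch.randn(8, device="cuda")
    stream = torch.cuda.current_stream(x.device)  # what get_tensor_stream() returns
    assert isinstance(stream._as_parameter_, ctypes.c_void_p)
    print(hex(stream.cuda_stream))                # raw cudaStream_t handle
```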
Supported values: [2048, 4096, 1024, 512, 256, 128, 64]", ) is_on_gpu([A, absmax, out]) + stream = get_tensor_stream(A) if out.dtype == torch.float32: lib.cdequantize_blockwise_fp32( get_ptr(quant_state.code), @@ -981,6 +987,7 @@ def dequantize_blockwise( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()), + stream, # Used the _as_parameter_ attribute of torch.cuda.Stream, Similarly for the following ) elif out.dtype == torch.float16: lib.cdequantize_blockwise_fp16( @@ -990,6 +997,7 @@ def dequantize_blockwise( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()), + stream, ) elif out.dtype == torch.bfloat16: lib.cdequantize_blockwise_bf16( @@ -999,6 +1007,7 @@ def dequantize_blockwise( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()), + stream, ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") @@ -1176,7 +1185,6 @@ def quantize_4bit( prev_device = pre_call(A.device) is_on_gpu([A, out, absmax]) - if A.dtype == torch.float32: if quant_type == "fp4": lib.cquantize_blockwise_fp32_fp4( @@ -1356,6 +1364,7 @@ def dequantize_4bit( device = pre_call(A.device) is_on_gpu([A, absmax, out]) + stream = get_tensor_stream(A) if out.dtype == torch.float32: if quant_state.quant_type == "fp4": lib.cdequantize_blockwise_fp32_fp4( @@ -1365,6 +1374,7 @@ def dequantize_4bit( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n), + stream, ) else: lib.cdequantize_blockwise_fp32_nf4( @@ -1374,6 +1384,7 @@ def dequantize_4bit( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n), + stream, ) elif out.dtype == torch.float16: if quant_state.quant_type == "fp4": @@ -1384,6 +1395,7 @@ def dequantize_4bit( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n), + stream, ) else: lib.cdequantize_blockwise_fp16_nf4( @@ -1393,6 +1405,7 @@ def dequantize_4bit( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n), + stream, ) elif out.dtype == torch.bfloat16: if quant_state.quant_type == "fp4": @@ -1403,6 +1416,7 @@ def dequantize_4bit( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n), + stream, ) else: lib.cdequantize_blockwise_bf16_nf4( @@ -1412,6 +1426,7 @@ def dequantize_4bit( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n), + stream, ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") @@ -1518,7 +1533,8 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = if out is None: out = torch.zeros_like(A, dtype=torch.float32) is_on_gpu([code, A, out]) - lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) + stream = get_tensor_stream(A) + lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()), stream) post_call(prev_device) return out @@ -2002,7 +2018,7 @@ def gemv_4bit( lda = ct.c_int32(lda) ldb = ct.c_int32(ldb) ldc = ct.c_int32(ldc) - + stream = get_tensor_stream(A) if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: if A.dtype == torch.float16: lib.cgemm_4bit_inference_naive_fp16( @@ -2018,6 +2034,7 @@ def gemv_4bit( ldb, ldc, ct.c_int32(state.blocksize), + stream, ) elif A.dtype == torch.bfloat16: lib.cgemm_4bit_inference_naive_bf16( @@ -2033,6 +2050,7 @@ def gemv_4bit( ldb, ldc, ct.c_int32(state.blocksize), + stream, ) elif A.dtype == torch.float32: lib.cgemm_4bit_inference_naive_fp32( @@ -2048,6 +2066,7 @@ def gemv_4bit( ldb, ldc, ct.c_int32(state.blocksize), + stream, ) else: raise NotImplementedError(f"Matmul not implemented for 
data type {A.dtype}") diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 96f4359bf..20aff67a3 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -4,6 +4,10 @@ # LICENSE file in the root directory of this source tree. from .modules import ( Embedding, + Embedding4bit, + Embedding8bit, + EmbeddingFP4, + EmbeddingNF4, Int8Params, Linear4bit, Linear8bitLt, diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 40766ad41..6c78494aa 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -273,6 +273,7 @@ def from_prequantized( quantized_stats: Dict[str, Any], requires_grad: bool = False, device="cuda", + module: Optional["Linear4bit"] = None, **kwargs, ) -> "Params4bit": self = torch.Tensor._make_subclass(cls, data.to(device)) @@ -284,6 +285,10 @@ def from_prequantized( self.bnb_quantized = True self.quant_storage = data.dtype + self.module = module + + if self.module is not None: + self.module.quant_state = self.quant_state return self @@ -342,6 +347,23 @@ def to(self, *args, **kwargs): return new_param +def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]): + if getattr(module.weight, "quant_state", None) is not None: + return + + if getattr(module, "quant_state", None) is None: + warnings.warn( + "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.", + ) + + # the quant state got lost when the parameter got converted. This happens for example for fsdp + # since we registered the module, we can recover the state here + assert module.weight.shape[1] == 1 + if not isinstance(module.weight, Params4bit): + module.weight = Params4bit(module.weight, quant_storage=module.quant_storage, bnb_quantized=True) + module.weight.quant_state = module.quant_state + + class Linear4bit(nn.Linear): """ This class is the base module for the 4-bit quantization algorithm presented in [QLoRA](https://arxiv.org/abs/2305.14314). @@ -444,22 +466,12 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): destination[prefix + "weight." + k] = v if keep_vars else v.detach() def forward(self, x: torch.Tensor): + fix_4bit_weight_quant_state_from_module(self) + # weights are cast automatically as Int8Params, but the bias has to be cast manually if self.bias is not None and self.bias.dtype != x.dtype: self.bias.data = self.bias.data.to(x.dtype) - if getattr(self.weight, "quant_state", None) is None: - if getattr(self, "quant_state", None) is not None: - # the quant state got lost when the parameter got converted. This happens for example for fsdp - # since we registered the module, we can recover the state here - assert self.weight.shape[1] == 1 - if not isinstance(self.weight, Params4bit): - self.weight = Params4bit(self.weight, quant_storage=self.quant_storage, bnb_quantized=True) - self.weight.quant_state = self.quant_state - else: - print( - "FP4 quantization state not initialized. 
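A hedged round-trip sketch of the new `module=` argument to `Params4bit.from_prequantized` shown above (illustrative only; it assumes a CUDA device and the existing `QuantState.as_dict(packed=True)`/`from_dict` serialization helpers):

```python
# Hedged sketch: rebuilding a 4-bit weight from serialized stats with module=
# also restores quant_state on the owning layer, which is what
# fix_4bit_weight_quant_state_from_module relies on.
from bitsandbytes.nn import Linear4bit, Params4bit

layer = Linear4bit(64, 64, quant_type="nf4").cuda()     # quantizes the weight
packed = layer.weight.data                              # packed 4-bit storage
stats = layer.weight.quant_state.as_dict(packed=True)   # serialized quant state

restored = Params4bit.from_prequantized(packed, stats, module=layer)
assert layer.quant_state is restored.quant_state        # back-reference was set
```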
Please call .cuda() or .to(device) on the LinearFP4 layer first.", - ) if not self.compute_type_is_set: self.set_compute_type(x) self.compute_type_is_set = True @@ -653,6 +665,191 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices) +class Embedding8bit(nn.Embedding): + """ + This class implements [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm for embedding layer + + Quantization API is similar to Linear8bitLt: + ```python + import torch + import torch.nn as nn + + from bitsandbytes.nn import Embedding8bit + + fp16_module = nn.Embedding(128, 64) + int8_module = Embedding8bit(128, 64) + + int8_module.load_state_dict(fp16_module.state_dict()) + + int8_module = int8_module.to(0) # Quantization happens here + ``` + """ + + def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None): + super().__init__(num_embeddings, embedding_dim, device=device, dtype=dtype) + self.dtype = self.weight.data.dtype + + self.weight = Int8Params(self.weight.data, has_fp16_weights=False, requires_grad=False) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + raise NotImplementedError("Saving Embedding8bit module is not implemented") + + def forward(self, input: Tensor) -> Tensor: + if not hasattr(self.weight, "SCB"): + raise RuntimeError("Embedding layer is not quantized. Please call .cuda() or .to(device) first.") + + rows = self.weight.data + row_stats = self.weight.SCB + + assert rows.shape == (self.num_embeddings, self.embedding_dim) + assert row_stats.shape == (self.num_embeddings,) + + compressed_output = F.embedding(input, rows) + compressed_output_stats = F.embedding(input, row_stats.view(self.num_embeddings, 1)) + + output = compressed_output * (compressed_output_stats / 127.0) + + return output.to(self.dtype) + + +class Embedding4bit(nn.Embedding): + """ + This is the base class similar to Linear4bit. It implements the 4-bit quantization algorithm presented in + [QLoRA](https://arxiv.org/abs/2305.14314) for embeddings. + + Quantization API is similar to Linear4bit: + ```python + import torch + import torch.nn as nn + + from bitsandbytes.nn import Embedding4bit + + fp16_module = nn.Embedding(128, 64) + quantized_module = Embedding4bit(128, 64) + + quantized_module.load_state_dict(fp16_module.state_dict()) + + quantized_module = quantized_module.to(0) # Quantization happens here + ``` + """ + + def __init__( + self, + num_embeddings, + embedding_dim, + dtype=None, + quant_type="fp4", + quant_storage=torch.uint8, + device=None, + ): + super().__init__(num_embeddings, embedding_dim, device=device, dtype=dtype) + self.dtype = self.weight.data.dtype + + self.weight = Params4bit( + self.weight.data, + requires_grad=False, + compress_statistics=None, + quant_type=quant_type, + quant_storage=quant_storage, + module=self, + ) + + blocksize = self.weight.blocksize + + if embedding_dim % blocksize != 0: + warnings.warn( + f"Embedding size {embedding_dim} is not divisible by block size {blocksize}. 
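A short illustrative sketch of the blocksize divisibility condition checked in `Embedding4bit.__init__` above (assuming the default 4-bit blocksize of 64; not part of this patch):

```python
# Hedged sketch: Embedding4bit only takes the fast partial-dequantize path when
# embedding_dim is divisible by the 4-bit blocksize (64 by default); other
# widths fall back to dequantizing the whole table on every forward pass.
import torch
from bitsandbytes.nn import EmbeddingNF4

emb_fast = EmbeddingNF4(1000, 128).cuda()  # 128 % 64 == 0 -> partial dequantize
emb_slow = EmbeddingNF4(1000, 100).cuda()  # triggers the divisibility warning

ids = torch.randint(0, 1000, (4, 16), device="cuda")
print(emb_fast(ids).shape, emb_slow(ids).shape)  # both (4, 16, embedding_dim)
```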
" + "This will lead to slow inference.", + ) + + def _forward_with_partial_dequantize(self, input: Tensor): + assert self.embedding_dim % self.weight.quant_state.blocksize == 0 + + w_4bit_uint8 = self.weight.data.view(torch.uint8).view(self.num_embeddings * self.embedding_dim // 2, 1) + + output_4bit = torch.nn.functional.embedding( + weight=w_4bit_uint8.view(self.num_embeddings, self.embedding_dim // 2), + input=input, + ).view(-1, 1) + assert output_4bit.shape == (input.numel() * self.embedding_dim // 2, 1) + + blocks_per_emb = self.embedding_dim // self.weight.blocksize + + absmax = self.weight.quant_state.absmax + assert absmax.shape == (self.num_embeddings * blocks_per_emb,) + + output_absmax = torch.nn.functional.embedding( + weight=absmax.view(self.num_embeddings, blocks_per_emb), + input=input, + ).view( + -1, + ) + assert output_absmax.shape == (input.numel() * blocks_per_emb,) + + output_quant_state = copy.deepcopy(self.weight.quant_state) + output_quant_state.absmax = output_absmax + output_quant_state.shape = torch.Size((*input.shape, self.embedding_dim)) + + output = bnb.functional.dequantize_4bit(output_4bit, output_quant_state) + assert output.shape == (*input.shape, self.embedding_dim) + + return output.to(self.dtype) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + raise NotImplementedError("Saving Embedding4bit module is not implemented") + + def forward(self, input: Tensor) -> Tensor: + fix_4bit_weight_quant_state_from_module(self) + + if self.embedding_dim % self.weight.quant_state.blocksize == 0: + return self._forward_with_partial_dequantize(input) + + dequantized_weight = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) + + return torch.nn.functional.embedding( + weight=dequantized_weight, + input=input, + ).to(self.dtype) + + +class EmbeddingFP4(Embedding4bit): + def __init__( + self, + num_embeddings, + embedding_dim, + dtype=None, + quant_storage=torch.uint8, + device=None, + ): + super().__init__( + num_embeddings, + embedding_dim, + dtype=dtype, + quant_type="fp4", + quant_storage=quant_storage, + device=device, + ) + + +class EmbeddingNF4(Embedding4bit): + def __init__( + self, + num_embeddings, + embedding_dim, + dtype=None, + quant_storage=torch.uint8, + device=None, + ): + super().__init__( + num_embeddings, + embedding_dim, + dtype=dtype, + quant_type="nf4", + quant_storage=quant_storage, + device=device, + ) + + class Linear8bitLt(nn.Linear): """ This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm. diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index e9c857d49..f6fcd171d 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -153,12 +153,14 @@ def fill_qmap(self): def __setstate__(self, state): super().__setstate__(state) - def load_state_dict(self, state_dict): + def load_state_dict(self, state_dict, move_to_device=True): """Load an optimizer state. Arguments: state_dict (`dict`): An optimizer state (should be returned from a call to `state_dict`) to load. + move_to_device (`bool`, defaults to `True`): + Whether to move the optimizer's state to the device. 
""" # deepcopy, to be consistent with module API state_dict = deepcopy(state_dict) @@ -195,7 +197,8 @@ def cast(param, value): elif isinstance(value, dict): for k, v in value.items(): if k in self.non_castable_tensor_keys: - value[k] = v.to(param.device) + if move_to_device: + value[k] = v.to(param.device) else: value[k] = cast(param, v) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index e4d459961..0f8ec4b7e 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -20,6 +20,7 @@ #define NUM 4 #define NUM_BLOCK 4096 +__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda __device__ float atomicMax(float* address, float val) { @@ -462,50 +463,6 @@ __device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadran } } -template -__device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float *__restrict__ const smem_code, float x, float lower, float midpoint, float upper) -{ - int lower_pivot = QUADRANT*16-1 - 0; - int pivot = QUADRANT*16-1 + 16; - int upper_pivot = QUADRANT*16-1 + 31; - - float val = midpoint; - - // i>>=1 = {32, 16, 8, 4, 2, 1} - for(int i = 16; i > 0; i>>=1) - { - if(x > val) - { - lower_pivot = pivot; - lower = val; - pivot+=i; - } - else - { - upper_pivot = pivot; - upper = val; - pivot-=i; - } - val = smem_code[pivot]; - } - - if(x > val) - { - midpoint = (upper+val)*0.5f; - if(x > midpoint) - return upper_pivot; - else - return pivot; - } - else - { - midpoint = (lower+val)*0.5f; - if(x < midpoint) - return lower_pivot; - else - return pivot; - } -} __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n) { @@ -519,86 +476,6 @@ __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index } } -template -__global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n) -{ - typedef cub::WarpReduce WarpReduce; - __shared__ typename WarpReduce::TempStorage temp_storage; - typedef cub::BlockLoad LoadT; - __shared__ typename LoadT::TempStorage loadt; - - const int warp_idx = threadIdx.x/32; - const int valid_items = n - (blockIdx.x*BLOCK_SIZE) > BLOCK_SIZE ? BLOCK_SIZE : n - (blockIdx.x*BLOCK_SIZE); - - // BLOCK_SIZE/32 == number of warps - __shared__ int smem_max_indices[8*BLOCK_SIZE/32]; - __shared__ float smem_max_values[8*BLOCK_SIZE/32]; - - T values[8]; - T max1 = -64000.0f; - T max2 = -64000.0f; - int max_idx1 = -1; - int max_idx2 = -1; - int sign1 = -1; - int sign2 = -1; - - // 1. load 8 values per thread - // 2. compute 2-max in registers (64 max per warp) - // 3. do warp reduction + broadcast back - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - // 5. Repeat (3) 8 times for top 8 values in 256 - // 6. 
store with byte index - - LoadT(loadt).Load(&(A[(blockIdx.x*BLOCK_SIZE)]), values, valid_items, (T)0.0f); - #pragma unroll 8 - for(int i = 0; i < 8; i++) - { - T absval = fabsf(values[i]); - if(absval > max1) - { - max1 = values[i]; - sign1 = signbit(values[i]); - max_idx1 = 8*threadIdx.x + i; - } - else if(absval > max2) - { - max2 = values[i]; - sign2 = signbit(values[i]); - max_idx2 = 8*threadIdx.x + i; - } - } - - float warp_max; - for(int i = 0; i < 8; i++) - { - // 3. do warp reduction + broadcast back - warp_max = WarpReduce(temp_storage).Reduce(max1, cub::Max()); - warp_max = cub::ShuffleIndex<32>(warp_max, 0, 0xffffffff); - - // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest - if(warp_max == max1) - { - smem_max_values[warp_idx*8 + i] = sign1 != 0 ? -max1 : max1; - smem_max_indices[warp_idx*8 + i] = max_idx1; - - sign1 = sign2; - max1 = max2; - max_idx1 = max_idx2; - - max2 = -64000.0f; - } - __syncwarp(); - } - - if(threadIdx.x % 32 < 8) - { - // offset: 8 values per 256 input values - // - int offset = BLOCK_SIZE*blockIdx.x*BLOCK_SIZE/32*8; - } - -} - #define THREADS_ESTIMATE 512 #define NUM_ESTIMATE 8 #define BLOCK_ESTIMATE 4096 @@ -1560,7 +1437,8 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]; switch(OPTIMIZER) { - case MOMENTUM: + case ADAGRAD: + case MOMENTUM: if(step == 1) s1_vals[j] = (float)g_vals[j]; else @@ -1663,6 +1541,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, if(weight_decay > 0.0f) { switch(OPTIMIZER) { + case ADAGRAD: case MOMENTUM: case RMSPROP: g_val += ((float)p_vals[j])*weight_decay; @@ -1675,8 +1554,8 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0]; - switch(OPTIMIZER) - { + switch(OPTIMIZER){ + case ADAGRAD: case MOMENTUM: if(step == 1) s1_vals[j] = g_vals[j]; @@ -3055,45 +2934,6 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * } } - -//template __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB) -//{ -//// element-wise kernel -//// 1. Load batch x k into registers -//// 2. Load k x k into registers -//// 3. dequantize and store in second pair of k x k -//// 4. matmul -//// 5. sum with cub -//// 6. store outputs -//// TC kernel -//// use k warps per thread block -//// 1. threadblock use read-only cache to read in register tile for A into shared memory -//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments -//// 3. each warp reads a segment of values 16x32 from B -//// 4. do dequantization from register of B into second pair of registers -//// 5. store (4) into fragment -//// 6. matmul aggregate into fragment C -//// 7. aggregate files of C into shared memory block C -//// 8. sum (7) -//// 9. 
write outputs to matmul output matrix -//} - -template __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f) -{ - if(limit_base + ITEMS <= limit) - reinterpret_cast(local)[0] = reinterpret_cast(buffer)[idx/ITEMS]; - else - { - for(int k = 0; k < ITEMS; k++) - { - if(limit_base + k < limit) - local[k] = buffer[idx+k]; - else - local[k] = (T)zero_value; - } - } -} - #define WARPS 3 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { @@ -3311,13 +3151,28 @@ template __device__ void printnonzero(T *A, int num_values, const c printf("%s %i %f\n", strval, i, (float)A[i]); } -template __device__ void printnonzero(float *A, int num_values, const char*strval); -template __device__ void printnonzero(half *A, int num_values, const char*strval); -__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) { + //// element-wise kernel + //// 1. Load batch x k into registers + //// 2. Load k x k into registers + //// 3. dequantize and store in second pair of k x k + //// 4. matmul + //// 5. sum with cub + //// 6. store outputs + //// TC kernel + //// use k warps per thread block + //// 1. threadblock use read-only cache to read in register tile for A into shared memory + //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments + //// 3. each warp reads a segment of values 16x32 from B + //// 4. do dequantization from register of B into second pair of registers + //// 5. store (4) into fragment + //// 6. matmul aggregate into fragment C + //// 7. aggregate files of C into shared memory block C + //// 8. sum (7) + //// 9. 
write outputs to matmul output matrix #if __CUDA_ARCH__ >= 750 using namespace nvcuda; int col_offset = blockIdx.x *32; @@ -3911,6 +3766,8 @@ MAKE_PreconditionStatic8bit1State(RMSPROP, half) MAKE_PreconditionStatic8bit1State(RMSPROP, float) MAKE_PreconditionStatic8bit1State(LION, half) MAKE_PreconditionStatic8bit1State(LION, float) +MAKE_PreconditionStatic8bit1State(ADAGRAD, half) +MAKE_PreconditionStatic8bit1State(ADAGRAD, float) #define MAKE_optimizerStatic8bit1State(oname, gtype) \ template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \ @@ -3930,6 +3787,9 @@ MAKE_optimizerStatic8bit1State(RMSPROP, half) MAKE_optimizerStatic8bit1State(RMSPROP, float) MAKE_optimizerStatic8bit1State(LION, half) MAKE_optimizerStatic8bit1State(LION, float) +MAKE_optimizerStatic8bit1State(ADAGRAD, half) +MAKE_optimizerStatic8bit1State(ADAGRAD, float) + #define MAKE_PreconditionStatic8bit2State(oname, gtype) \ template __global__ void kPreconditionOptimizerStatic8bit2State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \ @@ -4075,3 +3935,6 @@ MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8) + +template __device__ void printnonzero(float *A, int num_values, const char*strval); +template __device__ void printnonzero(half *A, int num_values, const char*strval); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index a7fe3d700..15f31cbed 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -9,7 +9,6 @@ #ifndef kernels #define kernels -//template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); template__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); diff --git a/csrc/ops.cu b/csrc/ops.cu index 3a6ffdda8..ade3b13d1 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -44,11 +44,11 @@ void quantize(float *code, float *A, unsigned char *out, int n) CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -void dequantize(float *code, unsigned char *A, float *out, int n) +void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream) { int num_blocks = n/1024; num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; - kDequantize<<>>(code, A, out, n); + kDequantize<<>>(code, A, out, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -76,28 +76,21 @@ template void quantizeBlockwise(floa CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n, cudaStream_t stream) { + // printf("stream==%d\n",stream); int num_blocks = n/blocksize; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; int tile_size = (DATA_TYPE > 0) ? 
1024 : 512; - if(DATA_TYPE > 0) - kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n); + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize/2, n); else - kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n); + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB) -//{ -// int num_blocks = (colsB+32-1)/32; -// kMatmul_inference_4bit<<>>(A, B, out, lda, ldb, rowsA, colsA, colsB); -// CUDA_CHECK_RETURN(cudaPeekAtLastError()); -//} - template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, @@ -362,10 +355,6 @@ template int get_leading_dim(int dim1, int dim2) } } -template int get_leading_dim(int dim1, int dim2); -template int get_leading_dim(int dim1, int dim2); -template int get_leading_dim(int dim1, int dim2); - template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) { #ifdef NO_CUBLASLT @@ -411,15 +400,6 @@ template void trans #endif } -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); - template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { #ifdef NO_CUBLASLT @@ -693,9 +673,9 @@ template void gemm_host(int m, int n, int k, T * A, T* B, T * out //cout << m << endl; //cout << n << endl; //cout << k << endl; - //if(bits == 32) + if(bits == 32) //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); if(bits == 16) //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); gemm_device<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); @@ -724,12 +704,11 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } -template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) { int num_blocks = (m+3)/4; - - kgemm_4bit_inference_naive<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + 
kgemm_4bit_inference_naive<<< num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -753,9 +732,9 @@ template void func(float *A, float *B, float value, long n); template void func(float *A, float *B, float value, long n); template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream); +template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream); +template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream); //template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); @@ -795,15 +774,15 @@ template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(float * code, __n template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); -template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); -template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); -template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); 
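For reference, a hedged sketch of the Python-level path that reaches the 4-bit inference (gemv) kernel launched above, which is now run on the tensor's current CUDA stream (function names are from `bitsandbytes.functional`; dispatch behavior as in current releases):

```python
# Hedged sketch: for a single-token activation, bnb.matmul_4bit dispatches to
# the naive 4-bit inference (gemv) kernel instantiated above.
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

W = torch.randn(256, 128, device="cuda", dtype=torch.float16)
qW, state = F.quantize_4bit(W, quant_type="nf4")   # packed weight + QuantState

x = torch.randn(1, 128, device="cuda", dtype=torch.float16)
y = bnb.matmul_4bit(x, qW.t(), state)              # gemv path for batch size 1
print(y.shape)                                     # torch.Size([1, 256])
```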
+template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream); +template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream); +template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream); +template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream); #define MAKE_optimizer32bit(name, gtype) \ template void optimizer32bit(gtype* g, gtype* p, \ @@ -842,6 +821,9 @@ MAKE_optimizerStatic8bit(RMSPROP, half) MAKE_optimizerStatic8bit(RMSPROP, float) MAKE_optimizerStatic8bit(LION, half) MAKE_optimizerStatic8bit(LION, float) +MAKE_optimizerStatic8bit(ADAGRAD, half) +MAKE_optimizerStatic8bit(ADAGRAD, float) + #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ @@ -850,6 +832,7 @@ template void optimizerStatic8bitBlockwise(gtype* p, gtype* g MAKE_optimizerStatic8bitBlockwise(half, ADAM); MAKE_optimizerStatic8bitBlockwise(float, ADAM); +MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM); MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); @@ -863,4 +846,15 @@ MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); -MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); + +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 8b9a4f449..8d936fd43 
100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -7,6 +7,7 @@ #ifndef ops_H #define ops_H +#include #include #include #include @@ -142,9 +143,9 @@ class ContextCusparse template void estimateQuantiles(T *A, float *code, float offset, int n); void quantize(float *code, float *A, unsigned char *out, int n); -void dequantize(float *code, unsigned char *A, float *out, int n); +void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream); template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n, cudaStream_t stream); template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, @@ -195,7 +196,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream); template void func(T *A, T *B, T value, long n); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index ea2283504..1da522bfd 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -31,14 +31,14 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } -void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize) -{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } +void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) +{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } -void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) -{ gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } +void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) +{ gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } -void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float 
*datatype, float * out, int lda, int ldb, int ldc, int blocksize) -{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } +void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) +{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } #define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func(A, B, value, n); } \ @@ -126,17 +126,17 @@ void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } -void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ -void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } \ -void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); } \ +void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } \ +void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } \ -void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } -void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } -void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); } +void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } +void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } -void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ 
dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n); } -void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n); } -void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n, stream); } +void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream); } +void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream); } #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ @@ -195,11 +195,11 @@ extern "C" void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } - void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); } + void cdequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream){ dequantize(code, A, out, n, stream); } - void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream); } + void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream); } + void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream); } void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } @@ -209,17 +209,17 @@ extern "C" void 
cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream); } + void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream); } + void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream); } void cquantize_blockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream); } + void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream); } + void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); } #define MAKE_CFUNC32(name, gtype, gbits) \ void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \ 
@@ -405,14 +405,14 @@ extern "C" CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) - void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize) - { gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } + void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) + { gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } - void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) - { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } + void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) + { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } - void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize) - { gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } + void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) + { gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } #endif diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index fdfe19ee4..a72eb1967 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -12,6 +12,8 @@ title: 8-bit optimizers - local: algorithms title: Algorithms + - local: non_cuda_backends + title: Non-CUDA compute backends - local: fsdp_qlora title: FSDP-QLoRA - local: integrations diff --git a/docs/source/contributing.mdx b/docs/source/contributing.mdx index 4fe6b7541..5da42961e 100644 --- a/docs/source/contributing.mdx +++ b/docs/source/contributing.mdx @@ -5,8 +5,9 @@ ### Setup pre-commit hooks - Install pre-commit hooks with `pip install pre-commit`. -- Run `pre-commit autoupdate` once to configure the hooks. -- Re-run `pre-commit autoupdate` every time a new hook got added. +- Run `pre-commit install` once to install the hooks, so they will be run on every commit. +- If the hooks introduce changes, they'll be visible with `git diff`. Review them and `git add` them if everything is fine, then re-execute the before commit, it should pass now. +- If you want to manually trigger the hooks, you may do `pre-commit run --all-files` Now all the pre-commit hooks will be automatically run when you try to commit and if they introduce some changes, you need to re-add the changed files before being able to commit and push. diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 8187fbf81..175c3e29a 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -2,7 +2,7 @@ ## CUDA -bitsandbytes is only supported on CUDA GPUs for CUDA versions **11.0 - 12.5**. 
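As a quick illustration of the `BNB_CUDA_VERSION` override documented in the hunks below (and implemented by the regex change to `bitsandbytes/cextension.py` earlier in this diff), a minimal sketch; the `apply_override` helper is only for demonstration:

```python
# Hedged sketch: how the override rewrites the CUDA library name before loading.
import re

def apply_override(library_name: str, override_value: str) -> str:
    # e.g. "libbitsandbytes_cuda124.so" with BNB_CUDA_VERSION=117
    #   -> "libbitsandbytes_cuda117.so"
    return re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1)

print(apply_override("libbitsandbytes_cuda124.so", "117"))
```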
+bitsandbytes is only supported on CUDA GPUs for CUDA versions **11.0 - 12.5**. However, there's a multi-backend effort under way which is currently in alpha release, check [the respective section below in case you're interested to help us with early feedback](#multi-backend). The latest version of bitsandbytes builds on: @@ -31,7 +31,7 @@ To install from PyPI. pip install bitsandbytes ``` -### Compile from source +### Compile from source[[compile]] For Linux and Windows systems, you can compile bitsandbytes from source. Installing from source allows for more build options with different CMake configurations. @@ -61,7 +61,7 @@ git clone https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/ pip install -r requirements-dev.txt cmake -DCOMPUTE_BACKEND=cuda -S . make -pip install . +pip install -e . # `-e` for "editable" install, when developing BNB (otherwise leave that out) ``` > [!TIP] @@ -85,7 +85,7 @@ git clone https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/ pip install -r requirements-dev.txt cmake -DCOMPUTE_BACKEND=cuda -S . cmake --build . --config Release -python -m build --wheel +pip install -e . # `-e` for "editable" install, when developing BNB (otherwise leave that out) ``` Big thanks to [wkpark](https://github.com/wkpark), [Jamezo97](https://github.com/Jamezo97), [rickardp](https://github.com/rickardp), [akx](https://github.com/akx) for their amazing contributions to make bitsandbytes compatible with Windows. @@ -129,55 +129,89 @@ For example, to use a local install path: ```bash export BNB_CUDA_VERSION=117 -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/tim/local/cuda-11.7 +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/YOUR_USERNAME/local/cuda-11.7 ``` 3. Now when you launch bitsandbytes with these environment variables, the PyTorch CUDA version is overridden by the new CUDA version (in this example, version 11.7) and a different bitsandbytes library is loaded. - -## Intel CPU +## Multi-backend[[multi-backend]] > [!TIP] -> Intel CPU backend only supports building from source; for now, please follow the instructions below. +> This functionality is currently in preview and therefore not yet production-ready! Please reference [this guide](./non_cuda_backends) for more in-depth information about the different backends and their current status. -Like CUDA, you can compile bitsandbytes from source for Linux and Windows systems. Installing from source allows for more build options with different CMake configurations. +Please follow these steps to install bitsandbytes with device-specific backend support other than CUDA: - - +### Pip install the pre-built wheel (recommended for most) -To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed. Make sure you have a compiler installed to compile C++ (gcc, make, headers, etc.). For example, to install a compiler and CMake on Ubuntu: +WIP (will be added in the coming days) -```bash -apt-get install -y build-essential cmake -``` +### Compilation -We recommend installing **GCC >= 11** and have at least **GCC >= 6**. + + -Now to install the bitsandbytes package from source, run the following commands: +#### AMD GPU + +bitsandbytes is fully supported from ROCm 6.1 onwards (currently in alpha release). 
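Returning to the `BNB_CUDA_VERSION` override documented a few steps earlier: the same effect can be achieved from inside a Python session, provided the variable is set before `bitsandbytes` is imported, because the native library is resolved at import time. A minimal sketch with the version number as a placeholder; note that `LD_LIBRARY_PATH` still has to be exported in the shell before Python starts, since the dynamic loader only reads it at process startup.

```python
# Illustrative: BNB_CUDA_VERSION must be in the environment before the first import,
# otherwise the default library for the PyTorch CUDA version is loaded instead.
import os

os.environ["BNB_CUDA_VERSION"] = "117"  # placeholder: the CUDA version you want to force

import bitsandbytes  # emits the documented override warning and loads the cuda117 build, if present
```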
+ +> [!TIP] +> If you would like to install ROCm and PyTorch on bare metal, skip Docker steps and refer to our official guides at [ROCm installation overview](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/install-overview.html#rocm-install-overview) and [Installing PyTorch for ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/3rd-party/pytorch-install.html#using-wheels-package) (Step 3 of wheels build for quick installation). Please make sure to get PyTorch wheel for the installed ROCm version. ```bash -git clone --branch multi-backend-refactor https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/ +# Create a docker container with latest ROCm image, which includes ROCm libraries +docker pull rocm/dev-ubuntu-22.04:6.1.2-complete +docker run -it --device=/dev/kfd --device=/dev/dri --group-add video rocm/dev-ubuntu-22.04:6.1.2-complete +apt-get update && apt-get install -y git && cd home + +# Install pytorch compatible with above ROCm version +pip install torch --index-url https://download.pytorch.org/whl/rocm6.1/ + +# Install bitsandbytes from PyPI +# (This is supported on Ubuntu 22.04, Python 3.10, ROCm 6.1.0/6.1.1/6.1.2 and gpu arch - gfx90a, gfx942, gfx1100 +# Please install from source if your configuration doesn't match with these) +pip install bitsandbytes + +# Install bitsandbytes from source +# Clone bitsandbytes repo, ROCm backend is currently enabled on multi-backend-refactor branch +git clone --depth 1 -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/ + +# Install dependencies pip install -r requirements-dev.txt -pip install intel_extension_for_pytorch -cmake -DCOMPUTE_BACKEND=cpu -S . + +# Compile & install +apt-get install -y build-essential cmake # install build tools dependencies, unless present +cmake -DCOMPUTE_BACKEND=hip -S . # Use -DBNB_ROCM_ARCH="gfx90a;gfx942" to target specific gpu arch make -pip install . +pip install -e . # `-e` for "editable" install, when developing BNB (otherwise leave that out) ``` - + -Windows systems require Visual Studio with C++ support. +#### Intel CPU -To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed. +> [!TIP] +> Intel CPU backend only supports building from source; for now, please follow the instructions below. -```bash -git clone --branch multi-backend-refactor https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/ +Similar to the CUDA case, you can compile bitsandbytes from source for Linux and Windows systems. + +The below commands are for Linux. For installing on Windows, please adapt the below commands according to the same pattern as described [the section above on compiling from source under the Windows tab](#compile). + +``` +git clone --depth 1 -b multi-backend-refactor https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/ +pip install intel_extension_for_pytorch pip install -r requirements-dev.txt cmake -DCOMPUTE_BACKEND=cpu -S . -cmake --build . --config Release -pip install . +make +pip install -e . # `-e` for "editable" install, when developing BNB (otherwise leave that out) ``` + + + +#### Apple Silicon + +WIP + diff --git a/docs/source/non_cuda_backends.mdx b/docs/source/non_cuda_backends.mdx new file mode 100644 index 000000000..fca586534 --- /dev/null +++ b/docs/source/non_cuda_backends.mdx @@ -0,0 +1,27 @@ +# Multi-backend support (non-CUDA backends) + +As part of a recent refactoring effort, we will soon offer official multi-backend support. 
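Once one of the backends above is installed, a short smoke test helps confirm that the build is actually usable. The sketch below sticks to APIs that appear elsewhere in this diff (`bitsandbytes.nn.Linear4bit`); the device handling is an assumption rather than official guidance: ROCm builds of PyTorch expose the GPU through `torch.cuda`, while the Intel CPU backend runs on `"cpu"`. If an operation turns out to be unsupported on your backend, that is exactly the kind of feedback the alpha release is asking for.

```python
# Minimal multi-backend smoke test (illustrative, not an official example).
import torch
import bitsandbytes as bnb

# ROCm builds of PyTorch report the GPU via torch.cuda; otherwise fall back to the CPU backend.
device = "cuda" if torch.cuda.is_available() else "cpu"

layer = bnb.nn.Linear4bit(64, 64, compute_dtype=torch.float16, quant_type="nf4")
layer = layer.to(device)  # the weights are expected to be quantized when moved to the target device

x = torch.randn(1, 64, dtype=torch.float16, device=device)
with torch.no_grad():
    print(layer(x).shape)  # expect torch.Size([1, 64]) if the backend supports this op
```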
Currently, this feature is available in a preview alpha release, allowing us to gather early feedback from users to improve the functionality and identify any bugs. + +At present, the Intel CPU and AMD ROCm backends are considered fully functional. The Intel XPU backend has limited functionality and is less mature. + +Please refer to the [installation instructions](./installation#multi-backend) for details on installing the backend you intend to test (and hopefully provide feedback on). + +> [!Tip] +> Apple Silicon support is planned for Q4 2024. We are actively seeking contributors to help implement this, develop a concrete plan, and create a detailed list of requirements. Due to limited resources, we rely on community contributions for this implementation effort. To discuss further, please spell out your thoughts and discuss in [this GitHub discussion](https://github.com/bitsandbytes-foundation/bitsandbytes/discussions/1340) and tag `@Titus-von-Koeller` and `@matthewdouglas`. Thank you! + +## Alpha Release + +As we are currently in the alpha testing phase, bugs are expected, and performance might not meet expectations. However, this is exactly what we want to discover from **your** perspective as the end user! + +Please share and discuss your feedback with us here: + +- [Github Discussion: Multi-backend refactor: Alpha release ( AMD ROCm ONLY )](https://github.com/bitsandbytes-foundation/bitsandbytes/discussions/1339) +- [Github Discussion: Multi-backend refactor: Alpha release ( Intel ONLY )](https://github.com/bitsandbytes-foundation/bitsandbytes/discussions/1338) + +Thank you for your support! + +## Benchmarks + +### Intel + +### AMD diff --git a/include/Portable.h b/include/Portable.h index 090a25065..59b6dc840 100644 --- a/include/Portable.h +++ b/include/Portable.h @@ -26,8 +26,31 @@ typedef struct {float a; float b; float c; float d;} __m128; typedef struct {int a; int b; int c; int d;} __m128i; typedef struct {double a; double b;} __m128d; #endif +#elif defined(__powerpc64__) +#ifdef __CUDACC__ +#undef USE_VSX // Doesn't work with nvcc, undefined symbols +#else +#include +#undef USE_VSX // Not yet implemented +#endif +#undef USE_AVX // x86_64 only +#undef USE_AVX2 // x86_64 only +#undef USE_SSE2 // x86_64 only +#undef USE_SSE41 // x86_64 only +#undef USE_SSE42 // x86_64 only +#undef USE_FMA // x86_64 only +#ifdef USE_VSX +typedef vector float __m128; +typedef vector signed int __m128i; +typedef vector double __m128d; +#else +typedef struct {float a; float b; float c; float d;} __m128; +typedef struct {int a; int b; int c; int d;} __m128i; +typedef struct {double a; double b;} __m128d; +#endif #else #undef USE_NEON // ARM64 only +#undef USE_VSX // PPC only #ifdef __FMA__ #define USE_FMA #endif diff --git a/include/SIMD.h b/include/SIMD.h index 9d1410c73..6222af04f 100644 --- a/include/SIMD.h +++ b/include/SIMD.h @@ -41,7 +41,7 @@ template <> struct InstrFloatTraits } } -#if !defined(__aarch64__) +#if !defined(__aarch64__) && !defined(__powerpc64__) #ifdef USE_SSE42 #ifndef _MSC_VER #include diff --git a/include/Type.h b/include/Type.h index 16bf3e3ae..2c0ab6ed6 100644 --- a/include/Type.h +++ b/include/Type.h @@ -10,7 +10,7 @@ using std::size_t; namespace BinSearch { -enum InstrSet { Scalar, SSE, AVX, Neon }; +enum InstrSet { Scalar, SSE, AVX, Neon, VSX }; #define ALGOENUM(x, b) x, enum Algos diff --git a/requirements-ci.txt b/requirements-ci.txt index 182e1023e..25ff67295 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -1,6 +1,6 @@ # Requirements used for GitHub 
actions -pytest==8.3.1 +pytest==8.3.3 einops==0.8.0 lion-pytorch==0.2.2 scipy==1.10.1; python_version < "3.9" -scipy==1.14.0; python_version >= "3.9" +scipy==1.14.1; python_version >= "3.9" diff --git a/requirements-dev.txt b/requirements-dev.txt index 41211880c..aedd07966 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,9 +1,9 @@ # Requirements used for local development setuptools>=63 -pytest~=8.3.1 +pytest~=8.3.3 einops~=0.8.0 -wheel~=0.43.0 +wheel~=0.44.0 lion-pytorch~=0.2.2 -scipy~=1.14.0 +scipy~=1.14.1 pandas~=2.2.2 -matplotlib~=3.9.1 +matplotlib~=3.9.2 diff --git a/setup.py b/setup.py index 18de0fe5b..beba00922 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ def has_ext_modules(self): setup( name="bitsandbytes", - version="0.43.3.dev", + version="0.44.0.dev", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", description="k-bit optimizers and matrix multiplication routines.", @@ -37,7 +37,7 @@ def has_ext_modules(self): install_requires=["torch", "numpy"], extras_require={ "benchmark": ["pandas", "matplotlib"], - "test": ["scipy"], + "test": ["scipy", "lion_pytorch"], }, long_description=read("README.md"), long_description_content_type="text/markdown", diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index fc79a54b0..b13f8b6c6 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -33,6 +33,12 @@ def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning? +def test_get_cuda_bnb_library_path_override_nocublaslt(monkeypatch, cuda111_noblas_spec, caplog): + monkeypatch.setenv("BNB_CUDA_VERSION", "125") + assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda125_nocublaslt" + assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning? + + def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec): monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt" diff --git a/tests/test_modules.py b/tests/test_modules.py index 9d507c6b4..2176f1d48 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -1,3 +1,4 @@ +import inspect import math import einops @@ -616,7 +617,97 @@ def test_fp8linear(): assert bgraderr < 0.00002 -def test_4bit_warnings(requires_cuda): +@pytest.mark.parametrize("embedding_dim", [64, 65]) +@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str) +@pytest.mark.parametrize( + "embedding_class,quant_storage", + [ + (bnb.nn.Embedding8bit, None), + (bnb.nn.EmbeddingFP4, torch.uint8), + (bnb.nn.EmbeddingFP4, torch.float32), + (bnb.nn.EmbeddingNF4, torch.uint8), + (bnb.nn.EmbeddingNF4, torch.float32), + ], + ids=lambda x: x.__name__ if inspect.isclass(x) else str(x), +) +def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_storage): + num_embeddings = 128 + + src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to( + torch.float32 + ) * 2 - 1 # Embeddings filled with {-1, 1} values. 
It should compress losslessly + + emb_base = nn.Embedding( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + _freeze=True, + _weight=src_weight, + ) + if embedding_class is bnb.nn.Embedding8bit: + e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + else: + e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim, quant_storage=quant_storage) + + e.load_state_dict(emb_base.state_dict()) + + emb_base.cuda() + e.cuda() + + input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda") + + torch.testing.assert_close( + actual=e(input_tokens), + expected=emb_base(input_tokens), + ) + + +@pytest.mark.parametrize("embedding_dim", [64, 65]) +@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str) +@pytest.mark.parametrize( + "embedding_class,quant_storage", + [ + (bnb.nn.Embedding8bit, None), + (bnb.nn.EmbeddingFP4, torch.uint8), + (bnb.nn.EmbeddingFP4, torch.float32), + (bnb.nn.EmbeddingNF4, torch.uint8), + (bnb.nn.EmbeddingNF4, torch.float32), + ], + ids=lambda x: x.__name__ if inspect.isclass(x) else str(x), +) +def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_storage): + is_8bit = embedding_class is bnb.nn.Embedding8bit + + num_embeddings = 128 + + src_weight = torch.rand((num_embeddings, embedding_dim), dtype=torch.float32) + + emb_base = nn.Embedding( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + _freeze=True, + _weight=src_weight, + ) + if is_8bit: + e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + else: + e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim, quant_storage=quant_storage) + + e.load_state_dict(emb_base.state_dict()) + + emb_base.cuda() + e.cuda() + + input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda") + + torch.testing.assert_close( + actual=e(input_tokens), + expected=emb_base(input_tokens), + atol=0.05 if is_8bit else 0.20, + rtol=0.0, + ) + + +def test_4bit_linear_warnings(): dim1 = 64 with pytest.warns(UserWarning, match=r"inference or training"): @@ -642,3 +733,58 @@ def test_4bit_warnings(requires_cuda): net(inp) assert len(record) == 2 + + +def test_4bit_embedding_warnings(): + num_embeddings = 128 + default_block_size = 64 + + with pytest.warns(UserWarning, match=r"inference."): + net = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=default_block_size + 1) + net.cuda() + inp = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda") + net(inp) + + +def test_4bit_embedding_weight_fsdp_fix(): + num_embeddings = 64 + embedding_dim = 32 + + module = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + + module.cuda() + + module.weight.quant_state = None + + input_tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda") + + module(input_tokens) + + assert module.weight.quant_state is not None + + +def test_4bit_linear_weight_fsdp_fix(): + inp_size = 64 + out_size = 32 + + module = bnb.nn.Linear4bit(inp_size, out_size) + + module.cuda() + + module.weight.quant_state = None + + input_tensor = torch.randn((1, inp_size), device="cuda") + + module(input_tensor) + + assert module.weight.quant_state is not None + + +def test_embedding_not_implemented_error(): + with pytest.raises(NotImplementedError): + emb = bnb.nn.Embedding4bit(32, 32) + emb.state_dict() + + with pytest.raises(NotImplementedError): + emb = bnb.nn.Embedding8bit(32, 32) + 
emb.state_dict()
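The new embedding tests above double as usage documentation for the quantized embedding classes they exercise. A condensed sketch of that pattern (copy weights from a regular `nn.Embedding`, move to the GPU, look up tokens) follows; the sizes are arbitrary and this mirrors the tests rather than documenting an official workflow:

```python
# Condensed from the embedding tests above (illustrative).
import torch
import torch.nn as nn
import bitsandbytes as bnb

num_embeddings, embedding_dim = 128, 64

ref = nn.Embedding(num_embeddings, embedding_dim)
q_emb = bnb.nn.EmbeddingNF4(
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    quant_storage=torch.float32,  # non-default storage dtype, as covered by the parametrization
)
q_emb.load_state_dict(ref.state_dict())

ref.cuda()
q_emb.cuda()  # the weights are quantized to NF4 when the module is moved to the GPU

tokens = torch.randint(0, num_embeddings, (4, 10), device="cuda")
print((q_emb(tokens) - ref(tokens)).abs().max())  # small quantization error is expected

# As asserted in the last test, state_dict() on Embedding4bit/Embedding8bit currently
# raises NotImplementedError, so these modules cannot be serialized this way yet.
```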