merge changes from main
Titus-von-Koeller committed May 3, 2024
2 parents d62516f + 5b9ef77 commit 7f13c8f
Showing 18 changed files with 518 additions and 258 deletions.
3 changes: 3 additions & 0 deletions .git-blame-ignore-revs
@@ -12,3 +12,6 @@ ea7c14f8ef64924f2d0ff80df3cdabf2c7299848

# Reformat with ruff-format
5a4263f4dc05fe8f78f4111beab9f68a81deeab1

# CHANGELOG: to reverse chron order + mdformat
4743ff0d43e04e4cc3e5d8b9e7cd016c0defa36d
4 changes: 3 additions & 1 deletion .github/workflows/python-package.yml
@@ -63,10 +63,12 @@ jobs:
os: [ubuntu-latest, windows-latest]
arch: [x86_64, aarch64]
cuda_version:
["11.7.1", "11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2"]
["11.7.1", "11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.0"]
exclude:
- os: windows-latest # This probably requires arm64 Windows agents
arch: aarch64
- os: windows-latest # The Jimver/cuda-toolkit action used for Windows builds is not yet updated for CUDA 12.4.
cuda_version: "12.4.0"
- os: ubuntu-latest # Temporary. Takes too long, not ready yet.
arch: aarch64
runs-on: ${{ matrix.os }} # One day, we could run them on native agents. Azure supports this now but it's planned only for Q3 2023 for hosted agents
511 changes: 291 additions & 220 deletions CHANGELOG.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions README.md
@@ -1,5 +1,7 @@
# `bitsandbytes`

[![Downloads](https://static.pepy.tech/badge/bitsandbytes)](https://pepy.tech/project/bitsandbytes) [![Downloads](https://static.pepy.tech/badge/bitsandbytes/month)](https://pepy.tech/project/bitsandbytes) [![Downloads](https://static.pepy.tech/badge/bitsandbytes/week)](https://pepy.tech/project/bitsandbytes)

The `bitsandbytes` library is a lightweight Python wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and 8 & 4-bit quantization functions.

The library includes quantization primitives for 8-bit & 4-bit operations through `bitsandbytes.nn.Linear8bitLt` and `bitsandbytes.nn.Linear4bit`, and 8-bit optimizers through the `bitsandbytes.optim` module.
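
As an illustration (not part of this diff), a minimal sketch of the two entry points named above, with arbitrary layer sizes and learning rate:

```py
import bitsandbytes as bnb

# 8-bit linear layer (illustrative dimensions).
layer = bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False)

# 8-bit Adam optimizer for the layer's parameters.
optimizer = bnb.optim.Adam8bit(layer.parameters(), lr=1e-4)
```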
2 changes: 1 addition & 1 deletion bitsandbytes/__init__.py
@@ -63,4 +63,4 @@
"optim.optimizer.MockArgs": False,
}

__version__ = "0.44.0.dev"
__version__ = "0.43.2.dev"
7 changes: 4 additions & 3 deletions bitsandbytes/functional.py
@@ -866,11 +866,12 @@ def get_4bit_type(typename, device=None, blocksize=64):
if data is None:
raise NotImplementedError(f"Typename {typename} not supported")

data = Tensor(data)
data /= data.abs().max()
data = torch.tensor(data, device=device)
data.div_(data.abs().max())

assert data.numel() == 16

return data.to(device)
return data


def quantize_fp4(
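For context (not part of the diff), a short sketch of calling the function this hunk changes; after the change, the lookup table is created directly on the requested device:

```py
import bitsandbytes.functional as F

# 16-entry lookup table for the "nf4" 4-bit data type,
# normalized so the largest magnitude is 1.0.
nf4_values = F.get_4bit_type("nf4", device="cpu")
assert nf4_values.numel() == 16
```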
29 changes: 25 additions & 4 deletions bitsandbytes/nn/modules.py
@@ -14,7 +14,11 @@
from bitsandbytes.autograd._functions import get_tile_inds, undo_layout
from bitsandbytes.functional import QuantState
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import OutlierTracer
from bitsandbytes.utils import (
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
OutlierTracer,
)

T = TypeVar("T", bound="torch.nn.Module")

@@ -619,6 +623,16 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
return
weight_format = state_dict.pop(f"{prefix}weight_format", "row")

if isinstance(weight_format, torch.Tensor):
weight_format = weight_format.item()

# For new weights format storage type, we explicitly check
# if weights_format is on the mapping
if isinstance(weight_format, int) and weight_format not in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
raise ValueError(f"Expected supported weight format - got {weight_format}")
elif isinstance(weight_format, int) and weight_format in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format]

if weight_format != "row":
tile_indices = get_tile_inds(weight_format, weight.device)
state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
@@ -711,13 +725,20 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
if not self.state.has_fp16_weights:
if param_from_weight is not None:
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
destination[format_name] = "row"
destination[format_name] = torch.tensor(0, dtype=torch.uint8)
elif param_from_state is not None and not layout_reordered:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
destination[format_name] = "row"
destination[format_name] = torch.tensor(0, dtype=torch.uint8)
elif param_from_state is not None:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
destination[format_name] = self.state.formatB
weights_format = self.state.formatB
# At this point `weights_format` is a str
if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
raise ValueError(f"Unrecognized weights format {weights_format}")

weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format]

destination[format_name] = torch.tensor(weights_format, dtype=torch.uint8)

def _load_from_state_dict(
self,
4 changes: 4 additions & 0 deletions bitsandbytes/utils.py
@@ -374,3 +374,7 @@ def __eq__(self, other):
else self.state2 is other.state2
)
)


LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3}
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()}
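
As an illustration (not part of the diff) of how `_save_to_state_dict` and `maybe_rearrange_weight` use these mappings, a format name round-trips through a compact integer tag:

```py
import torch

from bitsandbytes.utils import (
    INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
    LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
)

# Saving: the weight format is stored as a uint8 scalar tensor.
tag = torch.tensor(LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING["col_turing"], dtype=torch.uint8)

# Loading: the format name is recovered from the stored tag.
assert INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[tag.item()] == "col_turing"
```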
2 changes: 1 addition & 1 deletion csrc/ops.cu
@@ -58,7 +58,7 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;

if(blocksize == 4096)
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, 0><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, DATA_TYPE><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 2048)
kQuantizeBlockwise<T, 2048, 4, 0, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 1024)
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -12,6 +12,8 @@
title: 8-bit optimizers
- local: algorithms
title: Algorithms
- local: fsdp_qlora
title: FSDP-QLoRA
- local: integrations
title: Integrations
- local: errors
106 changes: 106 additions & 0 deletions docs/source/fsdp_qlora.md
@@ -0,0 +1,106 @@
# FSDP-QLoRA

FSDP-QLoRA combines data parallelism (FSDP enables sharding model parameters, optimizer states, and gradients across GPUs), 4-bit quantization, and LoRA to train LLMs up to 70B parameters on a dual 24GB GPU system. This technique was released by [Answer.AI](https://www.answer.ai/posts/2024-03-06-fsdp-qlora) in collaboration with bitsandbytes to make training LLMs more efficient and accessible for everyone.

This guide briefly explains how bitsandbytes supports storing quantized weights to enable FSDP-QLoRA, and how to run training with the Hugging Face libraries.

> [!TIP]
> Other changes required for bitsandbytes to support FSDP-QLoRA, such as reconstructing the weights from the quantization metadata and preventing quantizing already quantized weights when they're moved from a CPU to GPU, are documented in this [Pull Request](https://github.com/TimDettmers/bitsandbytes/pull/970) and described in the [Enabling 70B Finetuning on Consumer GPUs](https://www.answer.ai/posts/2024-03-14-fsdp-qlora-deep-dive) blog post. We highly recommend reading these resources for a better understanding of FSDP-QLoRA!

## Quantized data storage

FSDP only supports sharding float data types, which can be problematic because quantized weights are typically stored as integer data types (uint8). bitsandbytes doesn't have this problem because it uses `StoreChar` to read and write quantized weights regardless of the storage data type. This makes it simple to add a `quant_storage` parameter to the [`~nn.Linear4bit`] and [`~nn.Params4bit`] classes and set it to `torch.uint8` to maintain backward compatibility with the codebase.

```py
import torch
import bitsandbytes as bnb

model = bnb.nn.Linear4bit(
input_features,
output_features,
quant_type="fp4",
quant_storage=torch.uint8,
)
```

With the `quant_storage` parameter, you can select any FSDP-supported data type to shard [`~nn.Linear4bit`] with, such as bfloat16, float16, or float32, as in the sketch below.
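
A minimal sketch, assuming illustrative layer dimensions:

```py
import torch
import bitsandbytes as bnb

model = bnb.nn.Linear4bit(
    1024,
    1024,
    quant_type="fp4",
    quant_storage=torch.bfloat16,  # float storage type so FSDP can shard the layer
)
```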

## Training

bitsandbytes is deeply integrated with the Hugging Face ecosystem, making it easy to use with libraries like [Transformers](https://hf.co/docs/transformers), [PEFT](https://hf.co/docs/peft), and [TRL](https://hf.co/docs/trl).

Before you begin, make sure you have the latest libraries installed.

```bash
pip install -U bitsandbytes accelerate transformers peft trl
```

> [!TIP]
> PEFT provides a configuration file ([fsdp_config_qlora.yaml](https://github.com/huggingface/peft/blob/main/examples/sft/configs/fsdp_config_qlora.yaml)), launch command ([run_peft_qlora_fsdp.sh](https://github.com/huggingface/peft/blob/main/examples/sft/run_peft_qlora_fsdp.sh)), and training script ([train.py](https://github.com/huggingface/peft/blob/main/examples/sft/train.py)) for FSDP-QLoRA. To learn more, check out the [Use PEFT QLoRA and FSDP for finetuning large models on multiple GPUs](https://huggingface.co/docs/peft/main/en/accelerate/fsdp#use-peft-qlora-and-fsdp-for-finetuning-large-models-on-multiple-gpus) documentation.

The important change that enables FSDP-QLoRA training is the `bnb_4bit_quant_storage` parameter in the [`~transformers.BitsAndBytesConfig`] class. This allows you to set the storage data type of the quantized weights to a float data type.

```py
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_storage=torch.bfloat16,
)
```

Pass the [`~transformers.BitsAndBytesConfig`] to a model to set it up for FSDP-QLoRA. You should set the `torch_dtype` parameter to match `bnb_4bit_quant_storage` so that the [`~nn.Linear4bit`] layers are wrapped identically to the `Linear` layers. If the storage types do not match, then each [`~nn.Linear4bit`] layer is wrapped individually.

```py
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-70b",
quantization_config=bnb_config,
torch_dtype=torch.bfloat16,
)
```

Configure the [`~peft.LoraConfig`] class for QLoRA training by setting `target_modules="all-linear"`.

```py
from peft import LoraConfig

peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=64,
bias="none",
task_type="CAUSAL_LM",
target_modules="all-linear",
)
```

Now you can pass everything to the [`~trl.SFTTrainer`] for training.

```py
from trl import SFTTrainer

trainer = SFTTrainer(
model=model,
train_dataset=dataset,
peft_config=peft_config,
dataset_text_field="text",
max_seq_length=max_seq_length,
tokenizer=tokenizer,
args=training_arguments,
)
trainer.train()
```

## Resources

To learn more about FSDP and QLoRA, check out the following resources:

- The [AnswerDotAI/fsdp_qlora](https://github.com/AnswerDotAI/fsdp_qlora) repository.
- The introductory [You can now train a 70b language model at home](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) blog post by Answer.AI.
- For an introduction to FSDP, read the [Introducing PyTorch Fully Sharded Data Parallel (FSDP) API](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api) blog post.
- For more details about QLoRA, take a look at the [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) blog post.
43 changes: 32 additions & 11 deletions docs/source/installation.mdx
@@ -1,9 +1,17 @@
# Installation

bitsandbytes is only supported on CUDA GPUs for CUDA versions **11.0 - 12.3**. Select your operating system below to see the installation instructions.
bitsandbytes is only supported on CUDA GPUs for CUDA versions **11.0 - 12.3**.

<hfoptions id="OS system">
<hfoption id="Linux">
The latest version of bitsandbytes (v0.43.0) builds on:

| OS | CUDA | Compiler |
|---|---|---|
| Linux | 11.7 - 12.3 | GCC 11.4 |
| | 12.4+ | GCC 13.2 |
| Windows | 11.7 - 12.4 | MSVC 19.38+ (VS2022 17.8.0+) |

> [!TIP]
> MacOS support is still a work in progress! Subscribe to this [issue](https://github.com/TimDettmers/bitsandbytes/issues/1020) to get notified about discussions and to track the integration progress.

For Linux systems, make sure your hardware meets the following requirements to use bitsandbytes features.

@@ -23,13 +31,26 @@ pip install bitsandbytes

## Compile from source

For Linux and Windows systems, you can compile bitsandbytes from source. Installing from source allows for more build options with different CMake configurations.

<hfoptions id="source">
<hfoption id="Linux">

To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed. Make sure you have a compiler installed to compile C++ (gcc, make, headers, etc.). For example, to install a compiler and CMake on Ubuntu:

```bash
apt-get install -y build-essential cmake
```

You should also install CUDA Toolkit by following the [NVIDIA CUDA Installation Guide for Linux](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) guide from NVIDIA.
You should also install CUDA Toolkit by following the [NVIDIA CUDA Installation Guide for Linux](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) guide from NVIDIA. The current expected CUDA Toolkit version is **11.1+**; **GCC >= 7.3** is recommended, and at least **GCC >= 6** is required.

Refer to the following table if you're using another CUDA Toolkit version.

| CUDA Toolkit | GCC |
|---|---|
| >= 11.4.1 | >= 11 |
| >= 12.0 | >= 12 |
| >= 12.4 | >= 13 |

Now to install the bitsandbytes package from source, run the following commands:

@@ -49,7 +70,13 @@ pip install .

Windows systems require Visual Studio with C++ support as well as an installation of the CUDA SDK.

You'll need to build bitsandbytes from source. To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed. You should also install CUDA Toolkit by following the [CUDA Installation Guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) guide from NVIDIA.
To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed. You should also install CUDA Toolkit by following the [CUDA Installation Guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) guide from NVIDIA.

Refer to the following table if you're using another CUDA Toolkit version.

| CUDA Toolkit | MSVC |
|---|---|
| >= 11.6 | 19.30+ (VS2022) |

```bash
git clone https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/
```

@@ -61,12 +88,6 @@ python -m build --wheel

Big thanks to [wkpark](https://github.com/wkpark), [Jamezo97](https://github.com/Jamezo97), [rickardp](https://github.com/rickardp), [akx](https://github.com/akx) for their amazing contributions to make bitsandbytes compatible with Windows.

</hfoption>
<hfoption id="MacOS">

> [!TIP]
> MacOS support is still a work in progress! Subscribe to this [issue](https://github.com/TimDettmers/bitsandbytes/issues/1020) to get notified about discussions and to track the integration progress.
</hfoption>
</hfoptions>

2 changes: 1 addition & 1 deletion docs/source/integrations.mdx
@@ -12,7 +12,7 @@ With Transformers, it's very easy to load any model in 4 or 8-bit and quantize t
For example, to load and quantize a model to 4-bits and use the bfloat16 data type for compute:

> [!WARNING]
> bfloat16 is the optimal compute data type if your hardware supports it. The default is float32 for backward compatibility and numerical stability, but it can often lead to numerical instabilities. bfloat16 provides the best of both worlds, numerical stability equivalent to float32, but combined with the memory footprint and significant computation speedup of a 16-bit data type. Make sure to check if your hardware supports bfloat16 and if it does, configure it using the `bnb_4bit_compute_dtype` parameter in [`~transformers.BitsAndBytesConfig`]!
> bfloat16 is the ideal `compute_dtype` if your hardware supports it. While the default `compute_dtype`, float32, ensures backward compatibility (due to wide-ranging hardware support) and numerical stability, it is large and slows down computations. In contrast, float16 is smaller and faster but can lead to numerical instabilities. bfloat16 combines the best aspects of both; it offers the numerical stability of float32 and the reduced memory footprint and speed of a 16-bit data type. Check if your hardware supports bfloat16 and configure it using the `bnb_4bit_compute_dtype` parameter in [`~transformers.BitsAndBytesConfig`]!
```py
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
```
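
The rest of the example is elided above; a minimal sketch of the pattern the warning describes, with bfloat16 as the compute dtype (the model ID is a placeholder):

```py
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # numerically stable 16-bit compute
)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # placeholder model ID
    quantization_config=quantization_config,
)
```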
8 changes: 4 additions & 4 deletions requirements-ci.txt
@@ -1,6 +1,6 @@
# Requirements used for GitHub actions
pytest==8.1.1
einops==0.7.0
lion-pytorch==0.1.2
pytest==8.2.0
einops==0.8.0
lion-pytorch==0.1.4
scipy==1.10.1; python_version < "3.9"
scipy==1.12.0; python_version >= "3.9"
scipy==1.13.0; python_version >= "3.9"
12 changes: 6 additions & 6 deletions requirements-dev.txt
@@ -1,9 +1,9 @@
# Requirements used for local development
setuptools>=63
pytest~=8.1.1
einops~=0.7.0
pytest~=8.2.0
einops~=0.8.0
wheel~=0.43.0
lion-pytorch~=0.1.2
scipy~=1.12.0
pandas~=2.2.1
matplotlib~=3.8.3
lion-pytorch~=0.1.4
scipy~=1.13.0
pandas~=2.2.2
matplotlib~=3.8.4
2 changes: 1 addition & 1 deletion setup.py
@@ -25,7 +25,7 @@ def has_ext_modules(self):

setup(
name="bitsandbytes",
version="0.44.0.dev",
version="0.43.2.dev",
author="Tim Dettmers",
author_email="[email protected]",
description="k-bit optimizers and matrix multiplication routines.",
9 changes: 9 additions & 0 deletions tests/conftest.py
@@ -1,3 +1,5 @@
import gc

import pytest
import torch

@@ -20,6 +22,13 @@ def pytest_runtest_call(item):
raise


@pytest.hookimpl(trylast=True)
def pytest_runtest_teardown(item, nextitem):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()


@pytest.fixture(scope="session")
def requires_cuda() -> bool:
cuda_available = torch.cuda.is_available()