wip #1

Open · wants to merge 10 commits into base: origin_main
4 changes: 4 additions & 0 deletions bitsandbytes/nn/__init__.py
@@ -9,6 +9,10 @@
    Linear8bitLt,
    LinearFP4,
    LinearNF4,
    Embedding8bit,
    Embedding4bit,
    EmbeddingFP4,
    EmbeddingNF4,
    OutlierAwareLinear,
    Params4bit,
    StableEmbedding,
225 changes: 212 additions & 13 deletions bitsandbytes/nn/modules.py
@@ -338,6 +338,23 @@ def to(self, *args, **kwargs):
        return new_param


def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
    if getattr(module.weight, "quant_state", None) is not None:
        return

    if getattr(module, "quant_state", None) is None:
        warnings.warn(
            "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
        )
        # nothing to recover from, so bail out after warning
        return

    # the quant state got lost when the parameter got converted. This happens for example for fsdp
    # since we registered the module, we can recover the state here
    assert module.weight.shape[1] == 1
    if not isinstance(module.weight, Params4bit):
        module.weight = Params4bit(module.weight, quant_storage=module.quant_storage)
    module.weight.quant_state = module.quant_state
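
The inline comment above is terse about why the module keeps its own copy of the quantization state. A minimal, torch-only sketch of the failure mode (hypothetical, not part of this diff): attributes attached to a `Parameter` do not survive the parameter being re-created by a wrapper such as FSDP, so the copy registered on the module is the only place left to recover `quant_state` from.

```python
import torch.nn as nn

# Hypothetical illustration (not from this diff): extra attributes on a Parameter
# do not survive the parameter being re-created by a wrapper such as FSDP.
layer = nn.Linear(4, 4)
layer.weight.quant_state = {"blocksize": 64}  # stands in for the real QuantState object

# A wrapper flattens / re-registers the weight as a fresh nn.Parameter ...
layer.weight = nn.Parameter(layer.weight.data.clone())

print(hasattr(layer.weight, "quant_state"))  # False: the attribute is gone
# ... which is why fix_4bit_weight_quant_state_from_module() restores it from
# the copy that Params4bit registered on the owning module.
```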


class Linear4bit(nn.Linear):
"""
This class is the base module for the 4-bit quantization algorithm presented in [QLoRA](https://arxiv.org/abs/2305.14314).
@@ -440,22 +457,12 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
destination[prefix + "weight." + k] = v if keep_vars else v.detach()

def forward(self, x: torch.Tensor):
fix_4bit_weight_quant_state_from_module(self)

# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)

if getattr(self.weight, "quant_state", None) is None:
if getattr(self, "quant_state", None) is not None:
# the quant state got lost when the parameter got converted. This happens for example for fsdp
# since we registered the module, we can recover the state here
assert self.weight.shape[1] == 1
if not isinstance(self.weight, Params4bit):
self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
self.weight.quant_state = self.quant_state
else:
print(
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
)

if not self.compute_type_is_set:
self.set_compute_type(x)
self.compute_type_is_set = True
@@ -649,6 +656,198 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
        state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)


class Embedding8bit(nn.Embedding):
    """
    This class implements the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm for the embedding layer.

    The quantization API is similar to Linear8bitLt:
    ```python
    import torch
    import torch.nn as nn

    from bitsandbytes.nn import Embedding8bit

    fp16_module = nn.Embedding(128, 64)
    int8_module = Embedding8bit(128, 64)

    int8_module.load_state_dict(fp16_module.state_dict())

    int8_module = int8_module.to(0)  # Quantization happens here
    ```
    """

    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
        super().__init__(num_embeddings, embedding_dim, device=device, dtype=dtype)
        self.dtype = self.weight.data.dtype

        self.weight = Int8Params(self.weight.data, has_fp16_weights=False, requires_grad=False)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        raise NotImplementedError("saving Embedding8bit module is not implemented")

    def forward(self, input: Tensor) -> Tensor:
        if not hasattr(self.weight, "SCB"):
            raise RuntimeError(
                "Embedding layer is not quantized. Please call .cuda() or .to(device) first."
            )

        rows = self.weight.data
        row_stats = self.weight.SCB

        assert rows.shape == (self.num_embeddings, self.embedding_dim)
        assert row_stats.shape == (self.num_embeddings,)

        compressed_output = F.embedding(input, rows)
        compressed_output_stats = F.embedding(input, row_stats.view(self.num_embeddings, 1))

        # SCB holds the per-row absmax, so dequantization is int8 value * absmax / 127
        output = compressed_output * (compressed_output_stats / 127.0)

        return output.to(self.dtype)
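
For reference, the row-wise scheme that `Embedding8bit.forward` undoes with `compressed_output * (SCB / 127)` looks roughly as follows. This is a minimal, plain-torch sketch assuming symmetric round-to-nearest quantization; the real `Int8Params` kernels may round and store the row statistics differently.

```python
import torch

# Row-wise absmax int8 quantization, sketched in plain torch (illustrative only):
# each row is scaled by its own absolute maximum so values fit in [-127, 127];
# the per-row maxima play the role of weight.SCB above.
weight_fp16 = torch.randn(128, 64).to(torch.float16)

row_absmax = weight_fp16.float().abs().amax(dim=1)  # ~ weight.SCB
weight_int8 = torch.round(weight_fp16.float() * 127.0 / row_absmax.unsqueeze(1)).to(torch.int8)

# Dequantization mirrors Embedding8bit.forward: int8 value * (row absmax / 127)
weight_dequant = weight_int8.float() * (row_absmax.unsqueeze(1) / 127.0)
print((weight_dequant - weight_fp16.float()).abs().max())  # small quantization error
```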


class Embedding4bit(nn.Embedding):
    """
    This is the base class, similar to Linear4bit. It implements the 4-bit quantization algorithm presented in
    [QLoRA](https://arxiv.org/abs/2305.14314) for embeddings.

    The quantization API is similar to Linear4bit:
    ```python
    import torch
    import torch.nn as nn

    from bitsandbytes.nn import Embedding4bit

    fp16_module = nn.Embedding(128, 64)
    quantized_module = Embedding4bit(128, 64)

    quantized_module.load_state_dict(fp16_module.state_dict())

    quantized_module = quantized_module.to(0)  # Quantization happens here
    ```
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        dtype=None,
        quant_type="fp4",
        quant_storage=torch.uint8,
        device=None,
    ):
        super().__init__(num_embeddings, embedding_dim, device=device, dtype=dtype)
        self.dtype = self.weight.data.dtype

        self.weight = Params4bit(
            self.weight.data,
            requires_grad=False,
            compress_statistics=None,
            quant_type=quant_type,
            quant_storage=quant_storage,
            module=self,
        )

        blocksize = self.weight.blocksize

        if embedding_dim % blocksize != 0:
            warnings.warn(
                f"Embedding size {embedding_dim} is not divisible by block size {blocksize}. "
                "This will lead to slow inference.",
            )

    def _forward_with_partial_dequantize(self, input: Tensor):
        assert self.embedding_dim % self.weight.quant_state.blocksize == 0

        # Each uint8 storage element packs two 4-bit values, so one embedding row
        # occupies embedding_dim // 2 bytes of the flattened quantized weight.
        w_4bit_uint8 = self.weight.data.view(torch.uint8).view(self.num_embeddings * self.embedding_dim // 2, 1)

        output_4bit = torch.nn.functional.embedding(
            weight=w_4bit_uint8.view(self.num_embeddings, self.embedding_dim // 2),
            input=input,
        ).view(-1, 1)
        assert output_4bit.shape == (input.numel() * self.embedding_dim // 2, 1)

        blocks_per_emb = self.embedding_dim // self.weight.blocksize

        # Gather the per-block absmax entries for the selected rows in the same way.
        absmax = self.weight.quant_state.absmax
        assert absmax.shape == (self.num_embeddings * blocks_per_emb,)

        output_absmax = torch.nn.functional.embedding(
            weight=absmax.view(self.num_embeddings, blocks_per_emb),
            input=input,
        ).view(-1)
        assert output_absmax.shape == (input.numel() * blocks_per_emb,)

        # Dequantize only the gathered rows: reuse the quant_state with the gathered
        # absmax and the output shape.
        output_quant_state = copy.deepcopy(self.weight.quant_state)
        output_quant_state.absmax = output_absmax
        output_quant_state.shape = torch.Size((*input.shape, self.embedding_dim))

        output = bnb.functional.dequantize_4bit(output_4bit, output_quant_state)
        assert output.shape == (*input.shape, self.embedding_dim)

        return output.to(self.dtype)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        raise NotImplementedError("saving Embedding4bit module is not implemented")

    def forward(self, input: Tensor) -> Tensor:
        fix_4bit_weight_quant_state_from_module(self)

        if self.embedding_dim % self.weight.quant_state.blocksize == 0:
            return self._forward_with_partial_dequantize(input)

        # Fallback: dequantize the full embedding matrix, then do a regular lookup.
        dequantized_weight = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state)

        return torch.nn.functional.embedding(
            weight=dequantized_weight,
            input=input,
        ).to(self.dtype)
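
The fast path above relies on a layout property: when `embedding_dim` is a multiple of the blocksize, row `i` of the table owns a contiguous run of `embedding_dim // 2` packed bytes and `embedding_dim // blocksize` absmax entries, so an ordinary row gather selects exactly the data the dequantizer needs, and only the looked-up rows are decoded. A plain-torch sketch of that index arithmetic (made-up sizes, no actual 4-bit decoding):

```python
import torch

# Illustrative gather, mirroring _forward_with_partial_dequantize (made-up sizes,
# no actual 4-bit decoding here).
num_embeddings, embedding_dim, blocksize = 8, 16, 4

packed = torch.randint(0, 256, (num_embeddings * embedding_dim // 2,), dtype=torch.uint8)  # two 4-bit values per byte
absmax = torch.rand(num_embeddings * embedding_dim // blocksize)                           # one scale per block

ids = torch.tensor([[3, 5], [0, 7]])
flat_ids = ids.reshape(-1)

# Row i owns bytes [i*dim//2, (i+1)*dim//2) and scales [i*dim//bs, (i+1)*dim//bs),
# so a plain row gather pulls out exactly what the dequantizer needs.
gathered_packed = packed.view(num_embeddings, embedding_dim // 2).index_select(0, flat_ids)
gathered_absmax = absmax.view(num_embeddings, embedding_dim // blocksize).index_select(0, flat_ids)

print(gathered_packed.shape)  # torch.Size([4, 8]) -- only the looked-up rows
print(gathered_absmax.shape)  # torch.Size([4, 4]) -- matching per-block scales
```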


class EmbeddingFP4(Embedding4bit):
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        dtype=None,
        quant_storage=torch.uint8,
        device=None,
    ):
        super().__init__(
            num_embeddings,
            embedding_dim,
            dtype=dtype,
            quant_type="fp4",
            quant_storage=quant_storage,
            device=device,
        )


class EmbeddingNF4(Embedding4bit):
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        dtype=None,
        quant_storage=torch.uint8,
        device=None,
    ):
        super().__init__(
            num_embeddings,
            embedding_dim,
            dtype=dtype,
            quant_type="nf4",
            quant_storage=quant_storage,
            device=device,
        )
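
An end-to-end usage sketch for the new subclasses (hypothetical, not part of this diff; it assumes a CUDA device, since quantization happens on the `.to(device)` call, and the comments about error size are illustrative):

```python
import torch
import torch.nn as nn

from bitsandbytes.nn import EmbeddingNF4

# Hedged usage sketch: quantization happens when the module is moved to the GPU.
fp16_module = nn.Embedding(128, 64, dtype=torch.float16)
nf4_module = EmbeddingNF4(128, 64, dtype=torch.float16)
nf4_module.load_state_dict(fp16_module.state_dict())
nf4_module = nf4_module.to("cuda")

ids = torch.randint(0, 128, (2, 16), device="cuda")
out = nf4_module(ids)             # shape (2, 16, 64), dtype torch.float16
ref = fp16_module.to("cuda")(ids)
print((out - ref).abs().mean())   # small NF4 quantization error
```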


class Linear8bitLt(nn.Linear):
"""
This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm.