finish bnbEmbedding4bit's __init__ method
GM-git-dotcom committed Feb 28, 2024
1 parent 10ecb8a commit 9a3dc58
Showing 1 changed file with 31 additions and 3 deletions.
34 changes: 31 additions & 3 deletions src/peft/tuners/lora/bnb.py
@@ -19,7 +19,7 @@
 from bitsandbytes.nn import Params4bit
 import torch
 import torch.nn.functional as F
-from torch import Tensor
+from torch import Tensor, device

 from peft.import_utils import is_bnb_4bit_available, is_bnb_available
 from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
@@ -368,8 +368,36 @@ def __repr__(self) -> str:
         return "lora." + rep

 class bnbEmbedding4bit(torch.nn.Embedding):  # WORK IN PROGRESS
-    def __init__(self):
-        raise NotImplementedError
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: Optional[int] = None,
+        max_norm: Optional[float] = None,
+        norm_type: float = 2.0,
+        scale_grad_by_freq: bool = False,
+        sparse: bool = False,
+        _weight: Optional[Tensor] = None,
+        device: Optional[device] = None,
+        compute_dtype=None,
+        compress_statistics=True,
+        quant_type='fp4',  # 4-bit quantization data type: 'fp4' or 'nf4'
+    ) -> None:
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            padding_idx,
+            max_norm,
+            norm_type,
+            scale_grad_by_freq,
+            sparse,
+            _weight,
+            device=device,
+        )
+        self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)  # quantized to 4-bit on device transfer
+        self.compute_dtype = compute_dtype  # dtype used for de-quantized computation
+        self.compute_type_is_set = False  # tracks whether compute_dtype has been set
+        self.quant_state = None  # populated once the weights are quantized

     def reset_parameters(self) -> None:
         torch.nn.init.xavier_uniform_(self.weight)
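For context, a minimal usage sketch of the pattern this `__init__` follows. The class is still marked WORK IN PROGRESS, so this is illustrative rather than a supported API; it assumes a CUDA device, a bitsandbytes build with 4-bit support, and that the class is importable from this module:

```python
from peft.tuners.lora.bnb import bnbEmbedding4bit

# Construction only records the quantization settings; the weights are held
# in a bitsandbytes Params4bit, which performs the actual 4-bit quantization
# when the module is moved to a CUDA device.
emb = bnbEmbedding4bit(num_embeddings=1000, embedding_dim=128, quant_type='nf4')
emb = emb.cuda()  # Params4bit quantizes the weights here

# The storage is now 4-bit packed; a forward pass would still need a
# de-quantization step, which this work-in-progress class does not yet wire up.
print(emb.weight.quant_state)
```

This mirrors how bitsandbytes' `Linear4bit` wraps its weights: the constructor is cheap, and the expensive quantization happens on the host-to-device transfer.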
