From 87f88af4653d4490ff0a0e32ce9af302666b9075 Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Mon, 29 Jul 2024 10:07:37 -0400
Subject: [PATCH] Enable loading prequantized weights with bf16/fp16/fp32 quant_storage type for FSDP

---
 bitsandbytes/nn/modules.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 40766ad41..f113b3648 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -273,6 +273,7 @@ def from_prequantized(
         quantized_stats: Dict[str, Any],
         requires_grad: bool = False,
         device="cuda",
+        module: Optional["Linear4bit"] = None,
         **kwargs,
     ) -> "Params4bit":
         self = torch.Tensor._make_subclass(cls, data.to(device))
@@ -284,6 +285,10 @@ def from_prequantized(
         self.bnb_quantized = True
 
         self.quant_storage = data.dtype
+        self.module = module
+
+        if self.module is not None:
+            self.module.quant_state = self.quant_state
 
         return self
 
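
For context, a minimal usage sketch (not part of the patch) of how a checkpoint-loading path might call from_prequantized with the new module argument. Only Params4bit.from_prequantized, Linear4bit, and the module parameter come from the diff above; the load_prequantized_weight helper and its argument names are illustrative assumptions.

import torch
import bitsandbytes as bnb

def load_prequantized_weight(
    module: "bnb.nn.Linear4bit",
    data: torch.Tensor,
    quantized_stats: dict,
) -> "bnb.nn.Linear4bit":
    # Passing `module` lets from_prequantized attach the rebuilt quant_state to the
    # owning Linear4bit as well, which FSDP-wrapped models rely on when the packed
    # weight is stored with a bf16/fp16/fp32 quant_storage dtype instead of uint8.
    module.weight = bnb.nn.Params4bit.from_prequantized(
        data=data,
        quantized_stats=quantized_stats,
        device=data.device,
        module=module,
    )
    return module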