diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 40766ad41..f113b3648 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -273,6 +273,7 @@ def from_prequantized( quantized_stats: Dict[str, Any], requires_grad: bool = False, device="cuda", + module: Optional["Linear4bit"] = None, **kwargs, ) -> "Params4bit": self = torch.Tensor._make_subclass(cls, data.to(device)) @@ -284,6 +285,10 @@ def from_prequantized( self.bnb_quantized = True self.quant_storage = data.dtype + self.module = module + + if self.module is not None: + self.module.quant_state = self.quant_state return self