From 301ee803d291aa265ae0a765a84adfbd778ed030 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Fri, 2 Feb 2024 07:01:19 +0000
Subject: [PATCH] fix

---
 bitsandbytes/nn/modules.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index c2c19344d..4cd2da153 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -257,7 +257,6 @@ class Linear4bit(nn.Linear):
         ```
     """
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None):
-        super().__init__(input_features, output_features, bias, device)
         """
         Initialize Linear4bit class.

@@ -269,6 +268,7 @@ def __init__(self, input_features, output_features, bias=True, compute_dtype=Non
             bias (`bool`, defaults to `True`):
                 Whether the linear class uses the bias term as well.
         """
+        super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self)
         # self.persistent_buffers = [] # TODO consider as way to save quant state
         self.compute_dtype = compute_dtype
@@ -473,7 +473,6 @@ class Linear8bitLt(nn.Linear):
     """
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
                  memory_efficient_backward=False, threshold=0.0, index=None, device=None):
-        super().__init__(input_features, output_features, bias, device)
         """
         Initialize Linear8bitLt class.

@@ -484,7 +483,8 @@ def __init__(self, input_features, output_features, bias=True, has_fp16_weights=
                 Number of output features of the linear layer.
             bias (`bool`, defaults to `True`):
                 Whether the linear class uses the bias term as well.
-        """
+        """
+        super().__init__(input_features, output_features, bias, device)
         assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
         self.state = bnb.MatmulLtState()
         self.index = index
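
Note on why the reorder matters: Python only attaches a string literal to a function's `__doc__` when it is the *first* statement in the body. With `super().__init__(...)` placed before the string, the `__init__` docstrings of `Linear4bit` and `Linear8bitLt` were evaluated as no-op expressions and never surfaced in `help()` or the generated docs. A minimal standalone sketch of this behavior (the `Before`/`After` classes below are hypothetical illustrations, not part of bitsandbytes):

```python
class Before:
    def __init__(self):
        x = 1  # a statement comes first, as in the pre-patch __init__
        """Initialize Before class."""  # too late: plain expression, discarded

class After:
    def __init__(self):
        """Initialize After class."""  # first statement: becomes __doc__
        x = 1

print(Before.__init__.__doc__)  # None
print(After.__init__.__doc__)   # 'Initialize After class.'
```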