Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
younesbelkada committed Feb 2, 2024
1 parent daff94c commit 301ee80
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,6 @@ class Linear4bit(nn.Linear):
```
"""
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None):
super().__init__(input_features, output_features, bias, device)
"""
Initialize Linear4bit class.
Expand All @@ -269,6 +268,7 @@ def __init__(self, input_features, output_features, bias=True, compute_dtype=Non
bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well.
"""
super().__init__(input_features, output_features, bias, device)
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self)
# self.persistent_buffers = [] # TODO consider as way to save quant state
self.compute_dtype = compute_dtype
Expand Down Expand Up @@ -473,7 +473,6 @@ class Linear8bitLt(nn.Linear):
"""
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
memory_efficient_backward=False, threshold=0.0, index=None, device=None):
super().__init__(input_features, output_features, bias, device)
"""
Initialize Linear8bitLt class.
Expand All @@ -484,7 +483,8 @@ def __init__(self, input_features, output_features, bias=True, has_fp16_weights=
Number of output features of the linear layer.
bias (`bool`, defaults to `True`):
Whether the linear class uses the bias term as well.
"""
"""
super().__init__(input_features, output_features, bias, device)
assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
self.state = bnb.MatmulLtState()
self.index = index
Expand Down

0 comments on commit 301ee80

Please sign in to comment.