Commit bfa0e33

ran black and isort for coherent code formatting
1 parent: 597a852

25 files changed: 3,866 additions and 1,998 deletions
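The commit message says the tree was reformatted with black and isort. The exact commands and configuration are not recorded in the commit, but a rough sketch of the effect, using the tools' Python APIs on one of the import blocks changed below (this assumes black and isort >= 5 are installed and that their defaults roughly match the project's settings), looks like this:

    import black
    import isort

    # The import block from bitsandbytes/__init__.py below, before formatting.
    src = (
        "from .nn import modules\n"
        "from .autograd._functions import mm_cublas, bmm_cublas, matmul_cublas, matmul, MatmulLtState\n"
        "from .cextension import COMPILED_WITH_CUDA\n"
    )

    # isort groups and alphabetizes the imports; over-long lines are wrapped.
    sorted_src = isort.code(src)
    # black then normalizes quoting, spacing, and line breaks.
    formatted = black.format_str(sorted_src, mode=black.Mode())
    print(formatted)

Roughly, isort accounts for the regrouped, alphabetized imports in __init__.py, while black accounts for the double quotes, trailing commas, and wrapped call arguments seen throughout the rest of the diff; the exact wrapping style depends on the project's configuration.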

bitsandbytes/__init__.py

Lines changed: 11 additions & 9 deletions
@@ -1,16 +1,18 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from .nn import modules
-from .autograd._functions import mm_cublas, bmm_cublas, matmul_cublas, matmul, MatmulLtState
+from .autograd._functions import (MatmulLtState, bmm_cublas, matmul,
+                                  matmul_cublas, mm_cublas)
 from .cextension import COMPILED_WITH_CUDA
+from .nn import modules

 if COMPILED_WITH_CUDA:
     from .optim import adam

-__pdoc__ = {'libbitsandbytes': False,
-            'optim.optimizer.Optimizer8bit': False,
-            'optim.optimizer.MockArgs': False
-           }
+__pdoc__ = {
+    "libbitsandbytes": False,
+    "optim.optimizer.Optimizer8bit": False,
+    "optim.optimizer.MockArgs": False,
+}
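Because isort only reorders these imports, the names re-exported at package level should be unchanged. A quick sanity check along those lines (assuming bitsandbytes imports cleanly in your environment; this snippet is illustrative, not part of the commit):

    import bitsandbytes as bnb

    # Every name re-exported above should still be reachable at package level.
    for name in ("matmul", "mm_cublas", "bmm_cublas", "matmul_cublas",
                 "MatmulLtState", "COMPILED_WITH_CUDA", "modules"):
        assert hasattr(bnb, name), f"missing re-export: {name}"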

bitsandbytes/autograd/_functions.py

Lines changed: 88 additions & 46 deletions
@@ -1,21 +1,24 @@
+from dataclasses import dataclass
+
 import torch
+
 import bitsandbytes as bnb
 import bitsandbytes.functional as F

-from dataclasses import dataclass
-
 tensor = torch.Tensor

-'''
+"""
 This class pools outlier dimensions across layers.
 This is particularly important for small models where outlier features
 are less systematic and occur with low frequency.
-'''
+"""
+
+
 class GlobalOutlierPooler(object):
     _instance = None

     def __init__(self):
-        raise RuntimeError('Call get_instance() instead')
+        raise RuntimeError("Call get_instance() instead")

     def initialize(self):
         self.outliers = set()
@@ -29,25 +32,29 @@ def get_instance(cls):
         return cls._instance

     def add_outliers(self, outlier_idx, feature_dim):
-        if self.model_dim is None: self.model_dim = feature_dim
-        if feature_dim != self.model_dim: return  # we do not encode outliers for the 2nd FFN layer
+        if self.model_dim is None:
+            self.model_dim = feature_dim
+        if feature_dim != self.model_dim:
+            return  # we do not encode outliers for the 2nd FFN layer

         self.outliers.update(outlier_idx.tolist())

     def get_current_outlier_idx(self):
         return torch.Tensor(list(self.outliers)).to(torch.int64)

-class MatMul8bit(torch.autograd.Function):

+class MatMul8bit(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]):
+    def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):

         if precision[0] != 8:
             with torch.no_grad():
                 output = torch.matmul(A, B)
         else:
-            if len(B.shape) == 2: dim = 0
-            else: dim = 1
+            if len(B.shape) == 2:
+                dim = 0
+            else:
+                dim = 1
             qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
             qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
             iout = F.igemm(qA, qB)
@@ -84,21 +91,41 @@ def backward(ctx, grad_output):
             else:
                 if len(B.shape) == 2 and len(A.shape) == 3:
                     grad_output = grad_output.contiguous()
-                    if not grad_output.is_contiguous(): grad_output.contiguous()
-                    qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type)
-                    if not A.is_contiguous(): A = A.contiguous()
-                    qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
+                    if not grad_output.is_contiguous():
+                        grad_output.contiguous()
+                    qgrad_output, S1 = F.vectorwise_quant(
+                        grad_output.view(-1, grad_output.shape[2]),
+                        dim=0,
+                        quant_type=quant_type,
+                    )
+                    if not A.is_contiguous():
+                        A = A.contiguous()
+                    qA, S2 = F.vectorwise_quant(
+                        A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
+                    )
                     igrad_B = F.igemm(qA.t(), qgrad_output)
-                    grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
+                    grad_B = F.vectorwise_mm_dequant(
+                        igrad_B, S2.t(), S1, grad_output.dtype, quant_type
+                    )
                 else:
-                    qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+                    qgrad_output, S1 = F.vectorwise_quant(
+                        grad_output, dim=dims, quant_type=quant_type
+                    )
                     qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
                     igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
-                    grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type)
+                    grad_B = F.vectorwise_mm_dequant(
+                        igrad_B,
+                        S2.permute(permute_dim),
+                        S1,
+                        grad_output.dtype,
+                        quant_type,
+                    )

         if A.requires_grad:
-            if len(grad_output.shape) == 3: dims = [2]
-            else: dims = [1]
+            if len(grad_output.shape) == 3:
+                dims = [2]
+            else:
+                dims = [1]

             if len(B.shape) == 3:
                 # bio -> boi
@@ -113,10 +140,14 @@ def backward(ctx, grad_output):
                 with torch.no_grad():
                     grad_A = torch.matmul(grad_output, B.permute(permute_dim))
             else:
-                qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+                qgrad_output, S1 = F.vectorwise_quant(
+                    grad_output, dim=dims, quant_type=quant_type
+                )
                 qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                 igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
-                grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type)
+                grad_A = F.vectorwise_mm_dequant(
+                    igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type
+                )

         return grad_A, grad_B, None, None, None

@@ -125,6 +156,7 @@ def backward(ctx, grad_output):
 bmm_cublas = MatMul8bit.apply
 matmul_cublas = MatMul8bit.apply

+
 @dataclass
 class MatmulLtState:
     CB = None
@@ -159,7 +191,6 @@ def reset_grads(self):


 class MatMul8bitLt(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, A, B, out=None, state=MatmulLtState()):
         # 1. Quantize A
@@ -171,11 +202,15 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
         requires_gradB = B.requires_grad
         formatB = state.formatB
         input_shape = A.shape
-        if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance()
-        assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!'
+        if state.outlier_pool is None:
+            state.outlier_pool = GlobalOutlierPooler.get_instance()
+        assert (
+            A.dtype == torch.float16
+        ), f"The input data type needs to be fp16 but {A.dtype} was found!"

         # 1. Quantize A
-        if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous()
+        if len(A.shape) == 3:
+            A = A.view(-1, A.shape[-1]).contiguous()
         CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)

         if state.threshold > 0.0 and coo_tensorA is not None:
@@ -191,8 +226,8 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
                     # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                     # we also need to convert it to the turing/ampere format
                     state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
-                    #state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
-                    #if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
+                    # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
+                    # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
                     #     # generate outlier index and subB
                     #     outlier_idx = torch.unique(coo_tensorA.colidx).long()
                     #     state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
@@ -203,24 +238,24 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
                     #     state.idx = outlier_idx
                     #     state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()

-                    #if state.idx is not None:
+                    # if state.idx is not None:
                     #     # extract outliers
                     #     CA[:, state.idx] = 0
                     #     CAt[:, state.idx] = 0
                     #     subA = A[:, state.idx]
-                    #else:
+                    # else:
                     #     subA = None
         else:
             if not state.has_fp16_weights and state.CxB is None:
                 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
             subA = None

-
         # 2. Quantize B
         if state.has_fp16_weights:
-            has_grad = (True if (getattr(B, 'grad', None) is not None) else False)
+            has_grad = True if (getattr(B, "grad", None) is not None) else False
             is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
-            if is_transposed: B = B.contiguous()
+            if is_transposed:
+                B = B.contiguous()

             if (state.is_training and not has_grad) or state.CxB is None:
                 state.reset_grads()
@@ -234,14 +269,16 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):

             outlier_idx = torch.unique(coo_tensorA.colidx)
             state.idx = outlier_idx
-            #state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
-            #if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
+            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
+            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
             #     # do not use pool for 2nd FFN layer
             #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
-            #else:
+            # else:
             #     state.idx = outlier_idx
             outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
-            state.subB = (outliers*state.SCB.view(-1, 1)/127.0).t().contiguous().half()
+            state.subB = (
+                (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
+            )
             CA[:, state.idx.long()] = 0
             CAt[:, state.idx.long()] = 0
             subA = A[:, state.idx.long()]
@@ -254,7 +291,7 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
             output_shape = (input_shape[0], shapeB[0])

         # 3. Matmul
-        C32A, SA = F.transform(CA, 'col32')
+        C32A, SA = F.transform(CA, "col32")
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         output = F.mm_dequant(out32, Sout32, SCA, state.SCB)

@@ -277,7 +314,7 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)

-        #clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
+        # clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
         clone_func = torch.clone
         return clone_func(output.view(output_shape))

@@ -288,7 +325,7 @@ def backward(ctx, grad_output):
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.'
+        assert state.has_fp16_weights, "Backprop only supported for fp16 weights."

         if len(grad_output.shape) == 3:
             grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
@@ -298,28 +335,33 @@ def backward(ctx, grad_output):
         Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
         if req_gradB:
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
-            C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True)
+            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
             gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
             grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
             if state.threshold > 0.0 and subA is not None:
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

         if req_gradA:
-            C32grad, Sgrad = F.transform(Cgrad, 'col32')
+            C32grad, Sgrad = F.transform(Cgrad, "col32")
             if state.CxBt is None:
-                state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
+                state.CxBt, state.SBt = F.transform(
+                    state.CBt, to_order=formatB, transpose=True
+                )
             gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
-            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(
+                ctx.grad_shape
+            )

         return grad_A, grad_B, None, None, None, None, None


 matmul = MatMul8bitLt.apply


-def matmul(A : tensor, B : tensor, out : tensor=None, state : MatmulLtState = None, threshold=0.0):
+def matmul(
+    A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0
+):
     state = state or MatmulLtState()
     if threshold > 0.0:
         state.threshold = threshold
     return MatMul8bitLt.apply(A, B, out, state)
-
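For context on the reformatted matmul wrapper above: forward asserts fp16 inputs, and a nonzero threshold enables the mixed-precision outlier path. A minimal usage sketch (assuming a CUDA-enabled build and a GPU; the shapes and the threshold value of 6.0 are illustrative, not taken from this commit):

    import torch

    import bitsandbytes as bnb

    A = torch.randn(8, 1024, dtype=torch.float16, device="cuda")  # activations
    B = torch.randn(4096, 1024, dtype=torch.float16, device="cuda",
                    requires_grad=True)  # fp16 weight, shape (out, in)

    # state is created internally when omitted (state = state or MatmulLtState()).
    out = bnb.matmul(A, B, threshold=6.0)
    print(out.shape)  # torch.Size([8, 4096])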
bitsandbytes/cextension.py

Lines changed: 15 additions & 8 deletions
@@ -1,24 +1,29 @@
 import ctypes as ct
 import os
 from warnings import warn
+
 from bitsandbytes.cuda_setup import evaluate_cuda_setup


 class CUDALibrary_Singleton(object):
     _instance = None

     def __init__(self):
-        raise RuntimeError('Call get_instance() instead')
+        raise RuntimeError("Call get_instance() instead")

     def initialize(self):
         self.context = {}
         binary_name = evaluate_cuda_setup()
-        if not os.path.exists(os.path.dirname(__file__) + f'/{binary_name}'):
-            print(f'TODO: compile library for specific version: {binary_name}')
-            print('defaulting to libbitsandbytes.so')
-            self.lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
+        if not os.path.exists(os.path.dirname(__file__) + f"/{binary_name}"):
+            print(f"TODO: compile library for specific version: {binary_name}")
+            print("defaulting to libbitsandbytes.so")
+            self.lib = ct.cdll.LoadLibrary(
+                os.path.dirname(__file__) + "/libbitsandbytes.so"
+            )
         else:
-            self.lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + f'/{binary_name}')
+            self.lib = ct.cdll.LoadLibrary(
+                os.path.dirname(__file__) + f"/{binary_name}"
+            )

     @classmethod
     def get_instance(cls):
@@ -35,6 +40,8 @@ def get_instance(cls):
     lib.get_cusparse.restype = ct.c_void_p
     COMPILED_WITH_CUDA = True
 except AttributeError:
-    warn("The installed version of bitsandbytes was compiled without GPU support. "
-         "8-bit optimizers and GPU quantization are unavailable.")
+    warn(
+        "The installed version of bitsandbytes was compiled without GPU support. "
+        "8-bit optimizers and GPU quantization are unavailable."
+    )
     COMPILED_WITH_CUDA = False
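COMPILED_WITH_CUDA, set in the except branch above, is what gates the GPU-only features. The snippet below is a hypothetical caller that branches on it, mirroring the if COMPILED_WITH_CUDA guard in bitsandbytes/__init__.py (not part of this commit):

    from bitsandbytes.cextension import COMPILED_WITH_CUDA

    if COMPILED_WITH_CUDA:
        # 8-bit optimizers are only usable with a CUDA-enabled build.
        from bitsandbytes.optim import adam
    else:
        print("CPU-only build: 8-bit optimizers and GPU quantization are unavailable.")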
