+from dataclasses import dataclass
+
 import torch
+
 import bitsandbytes as bnb
 import bitsandbytes.functional as F

-from dataclasses import dataclass
-
 tensor = torch.Tensor

-'''
+"""
 This class pools outlier dimensions across layers.
 This is particularly important for small models where outlier features
 are less systematic and occur with low frequency.
-'''
+"""
+
+
 class GlobalOutlierPooler(object):
     _instance = None

     def __init__(self):
-        raise RuntimeError('Call get_instance() instead')
+        raise RuntimeError("Call get_instance() instead")

     def initialize(self):
         self.outliers = set()
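
To make the docstring above concrete, here is a minimal, hedged sketch of how the singleton pooler is meant to be used across layers; the index values and feature dimensions are hypothetical:

import torch

# Hypothetical usage of the GlobalOutlierPooler singleton defined in this file.
pool = GlobalOutlierPooler.get_instance()

# Record outlier feature indices observed in one layer; model_dim is fixed by the first call.
pool.add_outliers(torch.tensor([7, 42, 513]), feature_dim=1024)

# A call with a different feature_dim (e.g. the 2nd FFN projection) is ignored on purpose.
pool.add_outliers(torch.tensor([3]), feature_dim=4096)

# Pooled outlier dimensions across all layers seen so far, as an int64 index tensor.
idx = pool.get_current_outlier_idx()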
@@ -29,25 +32,29 @@ def get_instance(cls):
         return cls._instance

     def add_outliers(self, outlier_idx, feature_dim):
-        if self.model_dim is None: self.model_dim = feature_dim
-        if feature_dim != self.model_dim: return  # we do not encode outliers for the 2nd FFN layer
+        if self.model_dim is None:
+            self.model_dim = feature_dim
+        if feature_dim != self.model_dim:
+            return  # we do not encode outliers for the 2nd FFN layer

         self.outliers.update(outlier_idx.tolist())

     def get_current_outlier_idx(self):
         return torch.Tensor(list(self.outliers)).to(torch.int64)

-class MatMul8bit(torch.autograd.Function):

+class MatMul8bit(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]):
+    def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):

         if precision[0] != 8:
             with torch.no_grad():
                 output = torch.matmul(A, B)
         else:
-            if len(B.shape) == 2: dim = 0
-            else: dim = 1
+            if len(B.shape) == 2:
+                dim = 0
+            else:
+                dim = 1
             qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
             qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
             iout = F.igemm(qA, qB)
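
For intuition, the vector-wise path started above (vectorwise_quant of both operands, an integer igemm, then a dequant step) can be emulated in plain PyTorch; this is only a conceptual sketch of the scaling scheme with made-up shapes, not the library's kernels:

import torch

def vectorwise_quant_sketch(x, dim):
    # Symmetric 8-bit quantization with one scale per vector along `dim`.
    scale = x.abs().amax(dim=dim, keepdim=True)
    q = torch.clamp(torch.round(x / scale * 127), -127, 127)
    return q, scale

A = torch.randn(4, 16)   # (batch, inner)
B = torch.randn(16, 8)   # (inner, out)

qA, sA = vectorwise_quant_sketch(A, dim=1)   # one scale per row of A
qB, sB = vectorwise_quant_sketch(B, dim=0)   # one scale per column of B

iout = qA @ qB                        # stands in for the int32-accumulated igemm
out = iout * (sA * sB) / (127 * 127)  # stands in for vectorwise_mm_dequant
print((out - A @ B).abs().max())      # small quantization error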
@@ -84,21 +91,41 @@ def backward(ctx, grad_output):
             else:
                 if len(B.shape) == 2 and len(A.shape) == 3:
                     grad_output = grad_output.contiguous()
-                    if not grad_output.is_contiguous(): grad_output.contiguous()
-                    qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type)
-                    if not A.is_contiguous(): A = A.contiguous()
-                    qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
+                    if not grad_output.is_contiguous():
+                        grad_output.contiguous()
+                    qgrad_output, S1 = F.vectorwise_quant(
+                        grad_output.view(-1, grad_output.shape[2]),
+                        dim=0,
+                        quant_type=quant_type,
+                    )
+                    if not A.is_contiguous():
+                        A = A.contiguous()
+                    qA, S2 = F.vectorwise_quant(
+                        A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
+                    )
                     igrad_B = F.igemm(qA.t(), qgrad_output)
-                    grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
+                    grad_B = F.vectorwise_mm_dequant(
+                        igrad_B, S2.t(), S1, grad_output.dtype, quant_type
+                    )
                 else:
-                    qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+                    qgrad_output, S1 = F.vectorwise_quant(
+                        grad_output, dim=dims, quant_type=quant_type
+                    )
                     qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
                     igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
-                    grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type)
+                    grad_B = F.vectorwise_mm_dequant(
+                        igrad_B,
+                        S2.permute(permute_dim),
+                        S1,
+                        grad_output.dtype,
+                        quant_type,
+                    )

         if A.requires_grad:
-            if len(grad_output.shape) == 3: dims = [2]
-            else: dims = [1]
+            if len(grad_output.shape) == 3:
+                dims = [2]
+            else:
+                dims = [1]

             if len(B.shape) == 3:
                 # bio -> boi
@@ -113,10 +140,14 @@ def backward(ctx, grad_output):
                 with torch.no_grad():
                     grad_A = torch.matmul(grad_output, B.permute(permute_dim))
             else:
-                qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+                qgrad_output, S1 = F.vectorwise_quant(
+                    grad_output, dim=dims, quant_type=quant_type
+                )
                 qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                 igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
-                grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type)
+                grad_A = F.vectorwise_mm_dequant(
+                    igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type
+                )

         return grad_A, grad_B, None, None, None

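
The backward pass above is the int8 version of the standard matmul gradients; a quick plain-PyTorch check of the formulas it approximates, namely grad_A = grad_output @ B^T and grad_B = A^T @ grad_output:

import torch

A = torch.randn(4, 16, requires_grad=True)
B = torch.randn(16, 8, requires_grad=True)

C = A @ B
grad_output = torch.randn_like(C)
C.backward(grad_output)

# fp32 reference gradients that the quantized igemm calls above approximate
assert torch.allclose(A.grad, grad_output @ B.t(), atol=1e-6)
assert torch.allclose(B.grad, A.t() @ grad_output, atol=1e-6)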
@@ -125,6 +156,7 @@ def backward(ctx, grad_output):
 bmm_cublas = MatMul8bit.apply
 matmul_cublas = MatMul8bit.apply

+
 @dataclass
 class MatmulLtState:
     CB = None
@@ -159,7 +191,6 @@ def reset_grads(self):


 class MatMul8bitLt(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, A, B, out=None, state=MatmulLtState()):
         # 1. Quantize A
@@ -171,11 +202,15 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
         requires_gradB = B.requires_grad
         formatB = state.formatB
         input_shape = A.shape
-        if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance()
-        assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!'
+        if state.outlier_pool is None:
+            state.outlier_pool = GlobalOutlierPooler.get_instance()
+        assert (
+            A.dtype == torch.float16
+        ), f"The input data type needs to be fp16 but {A.dtype} was found!"

         # 1. Quantize A
-        if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous()
+        if len(A.shape) == 3:
+            A = A.view(-1, A.shape[-1]).contiguous()
         CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)

         if state.threshold > 0.0 and coo_tensorA is not None:
@@ -191,8 +226,8 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
                     # B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
                     # we also need to convert it to the turing/ampere format
                     state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
-                    #state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
-                    #if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
+                    # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
+                    # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
                     #     # generate outlier index and subB
                     #     outlier_idx = torch.unique(coo_tensorA.colidx).long()
                     #     state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
@@ -203,24 +238,24 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
                     #     state.idx = outlier_idx
                     #     state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()

-                    #if state.idx is not None:
+                    # if state.idx is not None:
                     #     # extract outliers
                     #     CA[:, state.idx] = 0
                     #     CAt[:, state.idx] = 0
                     #     subA = A[:, state.idx]
-                    #else:
+                    # else:
                     #     subA = None
         else:
             if not state.has_fp16_weights and state.CxB is None:
                 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
             subA = None

-
         # 2. Quantize B
         if state.has_fp16_weights:
-            has_grad = (True if (getattr(B, 'grad', None) is not None) else False)
+            has_grad = True if (getattr(B, "grad", None) is not None) else False
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
-            if is_transposed: B = B.contiguous()
+            if is_transposed:
+                B = B.contiguous()

             if (state.is_training and not has_grad) or state.CxB is None:
                 state.reset_grads()
@@ -234,14 +269,16 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):

             outlier_idx = torch.unique(coo_tensorA.colidx)
             state.idx = outlier_idx
-            #state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
-            #if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
+            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
+            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
             #     # do not use pool for 2nd FFN layer
             #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
-            #else:
+            # else:
             #     state.idx = outlier_idx
             outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
-            state.subB = (outliers*state.SCB.view(-1, 1)/127.0).t().contiguous().half()
+            state.subB = (
+                (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
+            )
             CA[:, state.idx.long()] = 0
             CAt[:, state.idx.long()] = 0
             subA = A[:, state.idx.long()]
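
The block above is the mixed-precision outlier decomposition: input columns holding outliers above state.threshold are zeroed in the int8 operands (CA/CAt), kept in fp16 as subA, and matched with fp16 weight rows in state.subB so their product can be added back to the int8 result later. A conceptual plain-PyTorch sketch of that split, with a hypothetical threshold:

import torch

threshold = 2.0          # hypothetical value, just to make outliers likely in this toy example
A = torch.randn(4, 16)
W = torch.randn(8, 16)   # weight laid out as (out_features, in_features)

# Columns of A that contain at least one entry above the threshold.
idx = torch.unique((A.abs() > threshold).nonzero()[:, 1])

A_inlier = A.clone()
A_inlier[:, idx] = 0     # this part would be quantized to int8 (CA / CAt)
subA = A[:, idx]         # outlier columns, kept in fp16 in the real code
subB = W[:, idx].t()     # matching weight rows, kept in fp16 (state.subB)

# int8 matmul on the inlier part (done in fp32 here for illustration) plus the
# fp16 outlier correction recovers the full product.
out = A_inlier @ W.t() + subA @ subB
print(torch.allclose(out, A @ W.t(), atol=1e-5))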
@@ -254,7 +291,7 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
             output_shape = (input_shape[0], shapeB[0])

         # 3. Matmul
-        C32A, SA = F.transform(CA, 'col32')
+        C32A, SA = F.transform(CA, "col32")
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         output = F.mm_dequant(out32, Sout32, SCA, state.SCB)

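
Pulling the forward() pieces together, a hedged, self-contained sketch of the three-step int8 matmul path (double_quant, transform, igemmlt, mm_dequant) using only the calls visible in this file; it needs a CUDA device with the bitsandbytes kernels, the shapes are made up, and "col_turing" is an assumption standing in for state.formatB:

import torch
import bitsandbytes.functional as F

A = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
W = torch.randn(4096, 1024, dtype=torch.float16, device="cuda")

# 1. Quantize both operands to int8 with row-wise statistics.
CA, CAt, SCA, SCAt, _ = F.double_quant(A)
CW, CWt, SCW, SCWt, _ = F.double_quant(W)

# 2. Transform to the layouts igemmlt expects ("col_turing" assumed here).
C32A, SA = F.transform(CA, "col32")
CxW, SW = F.transform(CW, to_order="col_turing")

# 3. int8 matmul with int32 accumulation, then dequantize back to fp16.
out32, Sout32 = F.igemmlt(C32A, CxW, SA, SW)
out = F.mm_dequant(out32, Sout32, SCA, SCW)   # ~ A @ W.t(), shape (8, 4096)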
@@ -277,7 +314,7 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)

-        #clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
+        # clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
         clone_func = torch.clone
         return clone_func(output.view(output_shape))

@@ -288,7 +325,7 @@ def backward(ctx, grad_output):
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.'
+        assert state.has_fp16_weights, "Backprop only supported for fp16 weights."

         if len(grad_output.shape) == 3:
             grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
@@ -298,28 +335,33 @@ def backward(ctx, grad_output):
         Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
         if req_gradB:
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
-            C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True)
+            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
             gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
             grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
             if state.threshold > 0.0 and subA is not None:
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

         if req_gradA:
-            C32grad, Sgrad = F.transform(Cgrad, 'col32')
+            C32grad, Sgrad = F.transform(Cgrad, "col32")
             if state.CxBt is None:
-                state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
+                state.CxBt, state.SBt = F.transform(
+                    state.CBt, to_order=formatB, transpose=True
+                )
             gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
-            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(
+                ctx.grad_shape
+            )

         return grad_A, grad_B, None, None, None, None, None


 matmul = MatMul8bitLt.apply


-def matmul(A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0):
+def matmul(
+    A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0
+):
     state = state or MatmulLtState()
     if threshold > 0.0:
         state.threshold = threshold
     return MatMul8bitLt.apply(A, B, out, state)
-
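
Finally, a minimal hedged sketch of calling the matmul wrapper defined above (assuming the usual top-level re-export as bnb.matmul); the shapes and threshold are illustrative, the inputs must be fp16 CUDA tensors, and B is laid out like a Linear weight, i.e. (out_features, in_features):

import torch
import bitsandbytes as bnb

A = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
W = torch.randn(2048, 1024, dtype=torch.float16, device="cuda", requires_grad=True)

# threshold > 0.0 enables the mixed-precision outlier decomposition.
out = bnb.matmul(A, W, threshold=6.0)   # fp16 result of shape (4, 2048)
out.sum().backward()                    # gradients flow through MatMul8bitLt.backward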