from dataclasses import dataclass

import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F

tensor = torch.Tensor

"""
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
"""
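
# Context note (editorial, inferred from the code below): "outlier dimensions"
# are the hidden-dimension indices whose activation magnitude crosses the
# decomposition threshold (MatmulLtState.threshold); those columns are handled
# in fp16 while the rest of the matmul runs in int8.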


class GlobalOutlierPooler(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.outliers = set()
        self.model_dim = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def add_outliers(self, outlier_idx, feature_dim):
        if self.model_dim is None:
            self.model_dim = feature_dim
        if feature_dim != self.model_dim:
            return  # we do not encode outliers for the 2nd FFN layer

        self.outliers.update(outlier_idx.tolist())

    def get_current_outlier_idx(self):
        return torch.Tensor(list(self.outliers)).to(torch.int64)
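

# Usage sketch (illustrative; the index values are made up). The pooler is a
# process-wide singleton, so calls from different layers accumulate into one
# shared set:
#
#   pooler = GlobalOutlierPooler.get_instance()
#   pooler.add_outliers(torch.tensor([3, 117]), feature_dim=1024)
#   pooler.add_outliers(torch.tensor([3, 512]), feature_dim=1024)
#   pooler.get_current_outlier_idx()  # int64 tensor over {3, 117, 512}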


class MatMul8bit(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):

        if precision[0] != 8:
            with torch.no_grad():
                output = torch.matmul(A, B)
        else:
            if len(B.shape) == 2:
                dim = 0
            else:
                dim = 1
            qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
            qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
            iout = F.igemm(qA, qB)
            output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)

        if A.requires_grad or B.requires_grad:
            ctx.save_for_backward(A, B)

        ctx.quant_type = quant_type
        ctx.precision = precision

        return output

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        quant_type = ctx.quant_type
        precision = ctx.precision
        grad_A = grad_B = None

        if B.requires_grad:
            if len(A.shape) == 3:
                dims = [0, 1]
                # bsi -> ibs
                permute_dim = [0, 2, 1]
            else:
                dims = [0]
                # bs -> sb
                permute_dim = [1, 0]

            if precision[1] != 8:
                with torch.no_grad():
                    grad_B = torch.matmul(A.permute(permute_dim), grad_output)
            else:
                if len(B.shape) == 2 and len(A.shape) == 3:
                    grad_output = grad_output.contiguous()
                    if not grad_output.is_contiguous():
                        grad_output = grad_output.contiguous()
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output.view(-1, grad_output.shape[2]),
                        dim=0,
                        quant_type=quant_type,
                    )
                    if not A.is_contiguous():
                        A = A.contiguous()
                    qA, S2 = F.vectorwise_quant(
                        A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
                    )
                    igrad_B = F.igemm(qA.t(), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B, S2.t(), S1, grad_output.dtype, quant_type
                    )
                else:
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output, dim=dims, quant_type=quant_type
                    )
                    qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
                    igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B,
                        S2.permute(permute_dim),
                        S1,
                        grad_output.dtype,
                        quant_type,
                    )

        if A.requires_grad:
            if len(grad_output.shape) == 3:
                dims = [2]
            else:
                dims = [1]

            if len(B.shape) == 3:
                # bio -> boi
                permute_dim = [0, 2, 1]
                dim_B = [2]
            else:
                # io -> oi
                permute_dim = [1, 0]
                dim_B = [1]

            if precision[2] != 8:
                with torch.no_grad():
                    grad_A = torch.matmul(grad_output, B.permute(permute_dim))
            else:
                qgrad_output, S1 = F.vectorwise_quant(
                    grad_output, dim=dims, quant_type=quant_type
                )
                qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
                grad_A = F.vectorwise_mm_dequant(
                    igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type
                )

        return grad_A, grad_B, None, None, None


mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply
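
# Usage sketch for the cublas autograd path (illustrative; the shapes are made
# up). The defaults quantize all three matmuls (forward plus both gradient
# matmuls) to 8-bit:
#
#   A = torch.randn(8, 64, 128, device="cuda", requires_grad=True)
#   B = torch.randn(8, 128, 256, device="cuda", requires_grad=True)
#   out = bmm_cublas(A, B)  # quant_type="vector", precision=[8, 8, 8]
#   out.sum().backward()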


@dataclass
class MatmulLtState:
    CB = None
    CxB = None
    SB = None
    SCB = None

    CxBt = None
    SBt = None
    CBt = None

    subB = None

    outlier_pool = None
    has_accumulated_gradients = False
    threshold = 0.0
    idx = None
    is_training = True
    has_fp16_weights = True
    use_pool = False
    formatB = F.get_special_format_str()

    def reset_grads(self):
        self.CB = None
        self.CxB = None
        self.SB = None
        self.SCB = None

        self.CxBt = None
        self.SBt = None
        self.CBt = None
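
# Note: a MatmulLtState instance caches the quantized weight (CB/CxB plus its
# scales) across calls, so one state object is meant to be reused per weight
# matrix; reset_grads() drops those caches so the next forward pass
# re-quantizes an updated weight.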


class MatMul8bitLt(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, state=MatmulLtState()):
        # 1. Quantize A
        # 2. Quantize B
        # 3. Matmul
        # 4. Mixed-precision decomposition matmul
        # 5. Save state
        requires_gradA = A.requires_grad
        requires_gradB = B.requires_grad
        formatB = state.formatB
        input_shape = A.shape
        if state.outlier_pool is None:
            state.outlier_pool = GlobalOutlierPooler.get_instance()
        assert (
            A.dtype == torch.float16
        ), f"The input data type needs to be fp16 but {A.dtype} was found!"

        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.view(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)

        if state.threshold > 0.0 and coo_tensorA is not None:
            if state.has_fp16_weights:
                idx = torch.unique(coo_tensorA.colidx).long()
                CA[:, idx] = 0
                CAt[:, idx] = 0
                subA = A[:, idx]
                state.subB = B[:, idx].t().contiguous()
                state.idx = idx
            else:
                if state.CxB is None:
                    # B is in 8-bit row-major; we can transform it back to
                    # 16-bit to extract outlier dimensions. We also need to
                    # convert it to the turing/ampere format.
                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
                    # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
                    # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
                    #     # generate outlier index and subB
                    #     outlier_idx = torch.unique(coo_tensorA.colidx).long()
                    #     state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
                    #     if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
                    #         # do not use pool for 2nd FFN layer
                    #         state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
                    #     else:
                    #         state.idx = outlier_idx
                    #     state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()

                # if state.idx is not None:
                #     # extract outliers
                #     CA[:, state.idx] = 0
                #     CAt[:, state.idx] = 0
                #     subA = A[:, state.idx]
                # else:
                #     subA = None
        else:
            if not state.has_fp16_weights and state.CxB is None:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None

        # 2. Quantize B
        if state.has_fp16_weights:
            has_grad = True if (getattr(B, "grad", None) is not None) else False
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
            if is_transposed:
                B = B.contiguous()

            if (state.is_training and not has_grad) or state.CxB is None:
                state.reset_grads()
                CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B)
                state.CxB, state.SB = F.transform(CB, to_order=formatB)
        else:
            has_grad = False

        if coo_tensorA is not None and not state.has_fp16_weights:
            # extract outliers

            outlier_idx = torch.unique(coo_tensorA.colidx)
            state.idx = outlier_idx
            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
            #     # do not use pool for 2nd FFN layer
            #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
            # else:
            #     state.idx = outlier_idx
            outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
            state.subB = (
                (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
            )
            CA[:, state.idx.long()] = 0
            CAt[:, state.idx.long()] = 0
            subA = A[:, state.idx.long()]

        shapeB = state.SB[0]

        if len(input_shape) == 3:
            output_shape = (input_shape[0], input_shape[1], shapeB[0])
        else:
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
        C32A, SA = F.transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
        output = F.mm_dequant(out32, Sout32, SCA, state.SCB)

        # 4. Mixed-precision decomposition matmul
        if coo_tensorA is not None and subA is not None:
            output += torch.matmul(subA, state.subB)

        # 5. Save state
        ctx.state = state

        ctx.formatB = formatB
        ctx.grad_shape = input_shape
        ctx.req_grads = [requires_gradA, requires_gradB]

        if requires_gradA or requires_gradB:
            ctx.tensors = (CAt, subA)
            ctx.tensor_states = (SCAt, state.idx)
        else:
            ctx.tensors = [None, None]
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        # clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
        clone_func = torch.clone
        return clone_func(output.view(output_shape))

    @staticmethod
    def backward(ctx, grad_output):
        req_gradA, req_gradB = ctx.req_grads
        CAt, subA = ctx.tensors
        SCAt, idx = ctx.tensor_states
        formatB = ctx.formatB
        state = ctx.state
        assert state.has_fp16_weights, "Backprop only supported for fp16 weights."

        if len(grad_output.shape) == 3:
            grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()

        grad_A = grad_B = None

        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
        if req_gradB:
            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
            if state.threshold > 0.0 and subA is not None:
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

        if req_gradA:
            C32grad, Sgrad = F.transform(Cgrad, "col32")
            if state.CxBt is None:
                state.CxBt, state.SBt = F.transform(
                    state.CBt, to_order=formatB, transpose=True
                )
            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(
                ctx.grad_shape
            )

        return grad_A, grad_B, None, None, None, None, None
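
# Note on backward() above: grad_A needs the transposed quantized weight
# (state.CBt, transformed into state.CxBt), which forward() only caches on the
# has_fp16_weights path; hence the assert at the top of backward().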


matmul = MatMul8bitLt.apply  # rebound by the def matmul(...) wrapper below


def matmul(
    A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0
):
    state = state or MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    return MatMul8bitLt.apply(A, B, out, state)
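

# End-to-end usage sketch (illustrative; shapes and threshold are made up).
# With threshold > 0.0, columns of A holding outliers are multiplied in fp16
# via subA @ state.subB while everything else runs through the int8 igemmlt
# path; the input must be fp16 (see the assert in forward()):
#
#   A = torch.randn(32, 1024, dtype=torch.float16, device="cuda", requires_grad=True)
#   W = torch.randn(4096, 1024, dtype=torch.float16, device="cuda", requires_grad=True)
#   out = matmul(A, W, threshold=6.0)  # out has shape (32, 4096)
#   out.sum().backward()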