 def replace_all_linear_with_lora(module, rank: int, scaling: float, device=None, dtype=None):
     """Recursively replace all Linear layers with LoRALinear layers."""
     for name, child in module.named_children():
-        if isinstance(child, nn.Linear):
-            if device is None:
-                this_device = child.weight.device
-            else:
-                this_device = device
-            if dtype is None:
-                this_dtype = child.weight.dtype
-            else:
-                this_dtype = dtype
-            lora = LoRALinear(child.in_features, child.out_features,
+        # Check for both nn.Linear and QLinear (from quantize.py)
+        if isinstance(child, nn.Linear) or (hasattr(child, 'weight') and hasattr(child, 'weight_scb')):
+            # For QLinear, we need to get in_features and out_features differently
+            if isinstance(child, nn.Linear):
+                in_features = child.in_features
+                out_features = child.out_features
+                if device is None:
+                    this_device = child.weight.device
+                else:
+                    this_device = device
+                if dtype is None:
+                    this_dtype = child.weight.dtype
+                else:
+                    this_dtype = dtype
+            else:  # QLinear
+                # For QLinear, we can infer dimensions from the weight shape:
+                # weight is [out_features, in_features] for both Linear and QLinear
+                if hasattr(child, 'weight') and child.weight.shape:
+                    if child.weight.device.type != 'meta':
+                        out_features = child.weight.shape[0]
+                        in_features = child.weight.shape[1]
+                    else:
+                        # For meta tensors, we need to be careful about padded dimensions:
+                        # QLinear pads out_features to a multiple of 8
+                        out_features = child.weight_scb.shape[0]  # this is the actual out_features
+                        in_features = child.weight.shape[1]
+                else:
+                    # If we can't determine the shape, skip this layer
+                    continue
+
+                if device is None:
+                    this_device = child.weight.device
+                else:
+                    this_device = device
+                if dtype is None:
+                    # For QLinear, use bfloat16 (or float16) for the LoRA adapters
+                    this_dtype = torch.bfloat16
+                else:
+                    this_dtype = dtype
+
+            lora = LoRALinear(in_features, out_features,
                               rank, scaling, device=this_device, dtype=this_dtype)
             lora.frozen_W = child
             setattr(module, name, lora)
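
Aside: the duck-typed check above treats anything with both `weight` and `weight_scb` attributes as a QLinear. A minimal self-contained sketch of that detection and of the padded-dimension issue, using a hypothetical stub in place of the real QLinear from quantize.py:

import torch
import torch.nn as nn

class FakeQLinear(nn.Module):
    """Hypothetical stand-in that only mimics the attributes the check relies on."""
    def __init__(self, in_features, out_features):
        super().__init__()
        padded_out = ((out_features + 7) // 8) * 8  # QLinear pads out_features to a multiple of 8
        self.weight = torch.zeros(padded_out, in_features, dtype=torch.int8)
        self.weight_scb = torch.zeros(out_features)

child = FakeQLinear(16, 10)
is_lora_target = isinstance(child, nn.Linear) or (hasattr(child, 'weight') and hasattr(child, 'weight_scb'))
print(is_lora_target)                                    # True
print(child.weight.shape[0], child.weight_scb.shape[0])  # 16 (padded) vs 10 (actual out_features)
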
@@ -26,17 +57,57 @@ def replace_lora_with_linear(module):
2657 """Recursively replace all LoRALinear layers with Linear layers."""
2758 for name , child in module .named_children ():
2859 if isinstance (child , LoRALinear ):
29- # Compute merged weights: W' = W + scaling * B @ A
30- merged_weight = child .frozen_W .weight .data + \
31- child .scaling * (child .lora_B .weight @ child .lora_A .weight )
32- # Create a standard Linear layer with the same in/out features
33- new_linear = nn .Linear (child .frozen_W .in_features ,
34- child .frozen_W .out_features , bias = False ,
35- device = torch .device ('meta' ),
36- dtype = merged_weight .dtype )
37- new_linear .weight = nn .Parameter (
38- merged_weight , requires_grad = merged_weight .requires_grad ) # Transfer merged weights
39- setattr (module , name , new_linear ) # Replace the module
+            # Check whether frozen_W is a QLinear or an nn.Linear
+            if hasattr(child.frozen_W, 'weight_scb'):
+                # For QLinear we have to convert back to nn.Linear, because its
+                # int8-quantized weights can't be added to the LoRA weights directly.
+
+                # First, compute the LoRA contribution
+                lora_contribution = child.scaling * (child.lora_B.weight @ child.lora_A.weight)
+
+                # Create a standard Linear layer with the same in/out features
+                new_linear = nn.Linear(child.in_features,
+                                       child.out_features, bias=False,
+                                       device=torch.device('meta'),
+                                       dtype=lora_contribution.dtype)
+
+                # For QLinear we run a forward pass to recover the dequantized weights.
+                # This is a workaround since the dequantized weights aren't directly
+                # accessible: feed an identity matrix and read the weights off the output.
+                with torch.no_grad():
+                    # Identity matrix as input extracts the weight matrix
+                    dummy_input = torch.eye(
+                        child.in_features,
+                        device=lora_contribution.device,
+                        dtype=torch.float16  # QLinear expects float16
+                    )
+                    # The output equals the transposed weight matrix
+                    dequantized_weight = child.frozen_W(dummy_input)
+                    # Convert to the same dtype as lora_contribution
+                    dequantized_weight = dequantized_weight.to(lora_contribution.dtype)
+                    # Transpose so the shape matches lora_contribution
+                    dequantized_weight = dequantized_weight.transpose(0, 1)
+                    # Add the LoRA contribution
+                    merged_weight = dequantized_weight + lora_contribution
+                    # Set the merged weights
+                    new_linear.weight = nn.Parameter(
+                        merged_weight, requires_grad=False)
+
+                setattr(module, name, new_linear)  # Replace the module
+            else:
+                # Standard nn.Linear case
+                # Compute merged weights: W' = W + scaling * B @ A
+                merged_weight = child.frozen_W.weight.data + \
+                    child.scaling * (child.lora_B.weight @ child.lora_A.weight)
+                # Create a standard Linear layer with the same in/out features
+                new_linear = nn.Linear(child.frozen_W.in_features,
+                                       child.frozen_W.out_features, bias=False,
+                                       device=torch.device('meta'),
+                                       dtype=merged_weight.dtype)
+                new_linear.weight = nn.Parameter(
+                    merged_weight, requires_grad=merged_weight.requires_grad)  # Transfer merged weights
+                setattr(module, name, new_linear)  # Replace the module
         else:
             replace_lora_with_linear(child)  # Recursively process submodules
 
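
The identity-matrix trick used above (and again in merge_weight below) works because a bias-free linear layer computes y = x @ W^T, so feeding x = I returns W^T and a transpose recovers W. A quick self-contained check, with a plain nn.Linear standing in for the quantized layer:

import torch
import torch.nn as nn

lin = nn.Linear(4, 3, bias=False)
with torch.no_grad():
    recovered = lin(torch.eye(4)).transpose(0, 1)  # shape (3, 4), same layout as lin.weight
assert torch.allclose(recovered, lin.weight)
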
@@ -103,19 +174,68 @@ def merge_weight(self):
 
         weight = up_weight.mm(down_weight) * self.scaling
 
-        weight += self.frozen_W.weight
+        # Handle both nn.Linear and QLinear for frozen_W
+        if isinstance(self.frozen_W, nn.Linear):
+            # Standard nn.Linear case
+            weight += self.frozen_W.weight
+        elif hasattr(self.frozen_W, 'weight_scb'):
+            # For QLinear, run a forward pass to recover the dequantized weights:
+            # an identity matrix as input extracts the weight matrix
+            dummy_input = torch.eye(
+                self.in_features,
+                device=weight.device,
+                dtype=torch.float16  # QLinear expects float16
+            )
+            # The output equals the transposed weight matrix
+            dequantized_weight = self.frozen_W(dummy_input)
+            # Convert to the same dtype as weight
+            dequantized_weight = dequantized_weight.to(weight.dtype)
+            # Transpose so the shape matches weight
+            dequantized_weight = dequantized_weight.transpose(0, 1)
+            # Add to the LoRA contribution
+            weight += dequantized_weight
+        else:
+            # Fallback for any other type
+            weight += self.frozen_W.weight
         return weight
 
     @staticmethod
     def _load_hook(module, state_dict, prefix, *_):
-        key_name = prefix + "weight"
-        if key_name in state_dict:
-            w_ref = state_dict.pop(key_name)
-            state_dict[prefix + 'frozen_W.weight'] = w_ref
+        qlinear_params = ("weight", "weight_scb", "weight_absmax",
+                          "bias", "bias_scb", "bias_absmax")  # add others if you use act-order
+
+        for name in qlinear_params:
+            key = prefix + name
+            if key in state_dict:
+                state_dict[f"{prefix}frozen_W.{name}"] = state_dict.pop(key)
 
     def forward(self, x: torch.Tensor):
         lora = self.lora_B(self.lora_A(x))
-        return self.frozen_W(x) + lora * self.scaling
+
+        # Handle both nn.Linear and QLinear for frozen_W
+        if isinstance(self.frozen_W, nn.Linear):
+            # Standard nn.Linear forward
+            return self.frozen_W(x) + lora * self.scaling
+        elif hasattr(self.frozen_W, 'weight_scb'):
+            # QLinear forward: ensure dtype compatibility.
+            # QLinear expects float16 input and returns float16 output,
+            # while the LoRA adapters are in float16/bfloat16.
+            x_dtype = x.dtype
+            if x_dtype != torch.float16:
+                x_for_frozen = x.to(torch.float16)
+            else:
+                x_for_frozen = x
+
+            frozen_output = self.frozen_W(x_for_frozen)
+
+            # Convert back to the original dtype if needed
+            if frozen_output.dtype != x_dtype:
+                frozen_output = frozen_output.to(x_dtype)
+
+            return frozen_output + lora * self.scaling
+        else:
+            # Fallback for any other type
+            return self.frozen_W(x) + lora * self.scaling
 
     def __repr__(self) -> str:
         return "{}Linear(in_features={}, out_features={}, r={})".format(
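
Finally, a small sanity-check sketch for the extended `_load_hook`: given checkpoint keys under a prefix, the hook reroutes every quantized parameter to `frozen_W.*` before `load_state_dict` consumes it. This assumes the `LoRALinear` class from this file is in scope; the key names below are illustrative only.

import torch

state_dict = {
    "layers.0.wq.weight": torch.zeros(8, 8, dtype=torch.int8),  # quantized weight
    "layers.0.wq.weight_scb": torch.zeros(8),                   # quantization state for the weight
    "layers.0.wq.lora_A.weight": torch.zeros(4, 8),             # untouched by the hook
}
LoRALinear._load_hook(None, state_dict, "layers.0.wq.")
print(sorted(state_dict))
# ['layers.0.wq.frozen_W.weight', 'layers.0.wq.frozen_W.weight_scb', 'layers.0.wq.lora_A.weight']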