
Commit f113cbd

Adds LoRA fine-tuning
1 parent 26dfb20 commit f113cbd

File tree

3 files changed: +159, -30 lines


moshi/moshi/models/loaders.py

Lines changed: 1 addition & 3 deletions
@@ -376,9 +376,7 @@ def get_moshi_lm(
         model.load_state_dict(pkg["fsdp_best_state"]["model"], assign=True)

     if lora:
-        assert not lm_kwargs.get("quantize"), (
-            "LoRA and quantization are incompatible for now."
-        )
+        # LoRA now supports quantized models
         model = get_lora_moshi(
             model=model,
             lora_rank=lora_rank,

moshi/moshi/modules/lora.py

Lines changed: 147 additions & 27 deletions
@@ -5,16 +5,47 @@
 def replace_all_linear_with_lora(module, rank: int, scaling: float, device=None, dtype=None):
     """ Recursively replace all Linear layers with LoRALinear layers."""
     for name, child in module.named_children():
-        if isinstance(child, nn.Linear):
-            if device is None:
-                this_device = child.weight.device
-            else:
-                this_device = device
-            if dtype is None:
-                this_dtype = child.weight.dtype
-            else:
-                this_dtype = dtype
-            lora = LoRALinear(child.in_features, child.out_features,
+        # Check for both nn.Linear and QLinear (from quantize.py).
+        if isinstance(child, nn.Linear) or (hasattr(child, 'weight') and hasattr(child, 'weight_scb')):
+            # For QLinear, in_features and out_features have to be recovered differently.
+            if isinstance(child, nn.Linear):
+                in_features = child.in_features
+                out_features = child.out_features
+                if device is None:
+                    this_device = child.weight.device
+                else:
+                    this_device = device
+                if dtype is None:
+                    this_dtype = child.weight.dtype
+                else:
+                    this_dtype = dtype
+            else:  # QLinear
+                # Infer the dimensions from the weight shape:
+                # weight is [out_features, in_features] for both Linear and QLinear.
+                if hasattr(child, 'weight') and child.weight.shape:
+                    if child.weight.device.type != 'meta':
+                        out_features = child.weight.shape[0]
+                        in_features = child.weight.shape[1]
+                    else:
+                        # For meta tensors, beware of padded dimensions:
+                        # QLinear pads out_features to a multiple of 8.
+                        out_features = child.weight_scb.shape[0]  # the actual out_features
+                        in_features = child.weight.shape[1]
+                else:
+                    # If the shape cannot be determined, skip this layer.
+                    continue
+
+                if device is None:
+                    this_device = child.weight.device
+                else:
+                    this_device = device
+                if dtype is None:
+                    # For QLinear, the LoRA adapters default to bfloat16.
+                    this_dtype = torch.bfloat16
+                else:
+                    this_dtype = dtype
+
+            lora = LoRALinear(in_features, out_features,
                               rank, scaling, device=this_device, dtype=this_dtype)
             lora.frozen_W = child
             setattr(module, name, lora)
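
A minimal usage sketch of the helper above on a plain module (hedged: the moshi.modules.lora import path follows the file layout shown here, and anything about the adapter's internals beyond what the diff shows is an assumption):

from torch import nn
from moshi.modules.lora import LoRALinear, replace_all_linear_with_lora

# Two bias-free Linear layers get wrapped; the GELU is left untouched.
model = nn.Sequential(nn.Linear(16, 32, bias=False), nn.GELU(), nn.Linear(32, 8, bias=False))
replace_all_linear_with_lora(model, rank=4, scaling=2.0)

assert isinstance(model[0], LoRALinear)
assert isinstance(model[0].frozen_W, nn.Linear)  # the original layer is kept as frozen_W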
@@ -26,17 +57,57 @@ def replace_lora_with_linear(module):
     """Recursively replace all LoRALinear layers with Linear layers."""
     for name, child in module.named_children():
         if isinstance(child, LoRALinear):
-            # Compute merged weights: W' = W + scaling * B @ A
-            merged_weight = child.frozen_W.weight.data + \
-                child.scaling * (child.lora_B.weight @ child.lora_A.weight)
-            # Create a standard Linear layer with the same in/out features
-            new_linear = nn.Linear(child.frozen_W.in_features,
-                                   child.frozen_W.out_features, bias=False,
-                                   device=torch.device('meta'),
-                                   dtype=merged_weight.dtype)
-            new_linear.weight = nn.Parameter(
-                merged_weight, requires_grad=merged_weight.requires_grad)  # Transfer merged weights
-            setattr(module, name, new_linear)  # Replace the module
+            # Check whether frozen_W is a QLinear or an nn.Linear.
+            if hasattr(child.frozen_W, 'weight_scb'):
+                # For QLinear we have to convert back to nn.Linear.
+                # This is because QLinear uses int8 quantization, which is not compatible
+                # with directly adding the LoRA weights.
+
+                # First, compute the LoRA contribution.
+                lora_contribution = child.scaling * (child.lora_B.weight @ child.lora_A.weight)
+
+                # Create a standard Linear layer with the same in/out features.
+                new_linear = nn.Linear(child.in_features,
+                                       child.out_features, bias=False,
+                                       device=torch.device('meta'),
+                                       dtype=lora_contribution.dtype)
+
+                # For QLinear, run a forward pass to get the dequantized weights:
+                # the dequantized weights are not directly accessible, so feed the
+                # layer an identity matrix and read the weights off the output.
+                with torch.no_grad():
+                    # Identity matrix as input to extract the weight matrix.
+                    dummy_input = torch.eye(
+                        child.in_features,
+                        device=lora_contribution.device,
+                        dtype=torch.float16  # QLinear expects float16
+                    )
+                    # The output of frozen_W on the identity is W.T.
+                    dequantized_weight = child.frozen_W(dummy_input)
+                    # Convert to the same dtype as lora_contribution.
+                    dequantized_weight = dequantized_weight.to(lora_contribution.dtype)
+                    # Transpose to match the shape of lora_contribution.
+                    dequantized_weight = dequantized_weight.transpose(0, 1)
+                    # Add the LoRA contribution.
+                    merged_weight = dequantized_weight + lora_contribution
+                    # Set the merged weights.
+                    new_linear.weight = nn.Parameter(
+                        merged_weight, requires_grad=False)
+
+                setattr(module, name, new_linear)  # Replace the module
+            else:
+                # Standard nn.Linear case.
+                # Compute merged weights: W' = W + scaling * B @ A
+                merged_weight = child.frozen_W.weight.data + \
+                    child.scaling * (child.lora_B.weight @ child.lora_A.weight)
+                # Create a standard Linear layer with the same in/out features
+                new_linear = nn.Linear(child.frozen_W.in_features,
+                                       child.frozen_W.out_features, bias=False,
+                                       device=torch.device('meta'),
+                                       dtype=merged_weight.dtype)
+                new_linear.weight = nn.Parameter(
+                    merged_weight, requires_grad=merged_weight.requires_grad)  # Transfer merged weights
+                setattr(module, name, new_linear)  # Replace the module
         else:
             replace_lora_with_linear(child)  # Recursively process submodules

@@ -103,19 +174,68 @@ def merge_weight(self):

         weight = up_weight.mm(down_weight) * self.scaling

-        weight += self.frozen_W.weight
+        # Handle both nn.Linear and QLinear for frozen_W.
+        if isinstance(self.frozen_W, nn.Linear):
+            # Standard nn.Linear case.
+            weight += self.frozen_W.weight
+        elif hasattr(self.frozen_W, 'weight_scb'):
+            # For QLinear, run a forward pass to get the dequantized weights:
+            # feed an identity matrix through the layer to extract the weight matrix.
+            dummy_input = torch.eye(
+                self.in_features,
+                device=weight.device,
+                dtype=torch.float16  # QLinear expects float16
+            )
+            # The output of frozen_W on the identity is W.T.
+            dequantized_weight = self.frozen_W(dummy_input)
+            # Convert to the same dtype as weight.
+            dequantized_weight = dequantized_weight.to(weight.dtype)
+            # Transpose to match the shape of weight.
+            dequantized_weight = dequantized_weight.transpose(0, 1)
+            # Add to the LoRA contribution.
+            weight += dequantized_weight
+        else:
+            # Fallback for any other type.
+            weight += self.frozen_W.weight
         return weight

     @staticmethod
     def _load_hook(module, state_dict, prefix, *_):
-        key_name = prefix + "weight"
-        if key_name in state_dict:
-            w_ref = state_dict.pop(key_name)
-            state_dict[prefix + 'frozen_W.weight'] = w_ref
+        qlinear_params = ("weight", "weight_scb", "weight_absmax",
+                          "bias", "bias_scb", "bias_absmax")  # add others if you use act-order
+
+        for name in qlinear_params:
+            key = prefix + name
+            if key in state_dict:
+                state_dict[f"{prefix}frozen_W.{name}"] = state_dict.pop(key)

     def forward(self, x: torch.Tensor):
         lora = self.lora_B(self.lora_A(x))
-        return self.frozen_W(x) + lora * self.scaling
+
+        # Handle both nn.Linear and QLinear for frozen_W.
+        if isinstance(self.frozen_W, nn.Linear):
+            # Standard nn.Linear forward.
+            return self.frozen_W(x) + lora * self.scaling
+        elif hasattr(self.frozen_W, 'weight_scb'):
+            # QLinear forward: ensure dtype compatibility.
+            # QLinear expects float16 input and returns float16 output,
+            # while the LoRA adapters are in float16 / bfloat16.
+            x_dtype = x.dtype
+            if x_dtype != torch.float16:
+                x_for_frozen = x.to(torch.float16)
+            else:
+                x_for_frozen = x
+
+            frozen_output = self.frozen_W(x_for_frozen)
+
+            # Convert back to the original dtype if needed.
+            if frozen_output.dtype != x_dtype:
+                frozen_output = frozen_output.to(x_dtype)
+
+            return frozen_output + lora * self.scaling
+
+        else:
+            # Fallback for any other type.
+            return self.frozen_W(x) + lora * self.scaling

     def __repr__(self) -> str:
         return "{}Linear(in_features={}, out_features={}, r={})".format(
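
merge_weight and forward are meant to agree: applying the adapter at runtime, frozen(x) + scaling * B(A(x)), is the same as multiplying by the merged matrix W + scaling * (B @ A). A small numerical sanity check of that equivalence using plain nn.Linear layers (the shapes here are arbitrary):

import torch
from torch import nn

torch.manual_seed(0)
in_f, out_f, rank, scaling = 8, 5, 2, 2.0
frozen = nn.Linear(in_f, out_f, bias=False)
lora_A = nn.Linear(in_f, rank, bias=False)
lora_B = nn.Linear(rank, out_f, bias=False)

x = torch.randn(3, in_f)
adapted = frozen(x) + scaling * lora_B(lora_A(x))                    # what forward() computes
merged = frozen.weight + scaling * (lora_B.weight @ lora_A.weight)   # what merge_weight() computes
assert torch.allclose(adapted, x @ merged.t(), atol=1e-6)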

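The extended _load_hook lets checkpoints saved before LoRA wrapping (where the quantized parameters live directly on the layer) load into the wrapped model, by renaming each matching key to its frozen_W counterpart. A minimal sketch of that renaming on a plain dict (the prefix and tensor shapes below are made up for illustration):

import torch

prefix = "transformer.layers.0.gating.linear_in."   # hypothetical module path
state_dict = {
    prefix + "weight": torch.zeros(8, 4, dtype=torch.int8),
    prefix + "weight_scb": torch.zeros(8),
}

# Same renaming loop as in _load_hook above.
for name in ("weight", "weight_scb", "weight_absmax",
             "bias", "bias_scb", "bias_absmax"):
    key = prefix + name
    if key in state_dict:
        state_dict[f"{prefix}frozen_W.{name}"] = state_dict.pop(key)

assert set(state_dict) == {prefix + "frozen_W.weight", prefix + "frozen_W.weight_scb"}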
moshi/moshi/utils/quantize.py

Lines changed: 11 additions & 0 deletions
@@ -108,5 +108,16 @@ def replace_linear_with_qlinear(module):
             # of the LM init, after all other modules are initialized and properly dtyped.
             # In any case that should happen before loading the state dict to avoid a loss of precision.
             child.float()
+        elif hasattr(module, 'modules') and hasattr(child, 'frozen_W'):
+            # This is likely a LoRALinear layer.
+            # Don't replace it directly, as it is handled by the LoRA-specific code;
+            # instead, replace its frozen_W with a QLinear.
+            try:
+                from ..modules.lora import LoRALinear
+                if isinstance(child, LoRALinear):
+                    child.frozen_W = QLinear(child.frozen_W)
+            except ImportError:
+                # If LoRALinear cannot be imported, just skip this layer.
+                pass
         else:
             replace_linear_with_qlinear(child)
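
Taken together with the lora.py changes, the intent is that quantization can be applied after the LoRA wrapping: the adapters stay in place and only their frozen base layers are quantized. A sketch of that call order, hedged: the import paths follow the file layout above, and actually constructing a QLinear requires the int8 quantization backend (typically CUDA), so treat this as an illustration of ordering rather than verified end-to-end code:

from torch import nn
from moshi.modules.lora import replace_all_linear_with_lora
from moshi.utils.quantize import replace_linear_with_qlinear

# Wrap the Linear layers with LoRA adapters first, then quantize: the new branch above
# swaps each adapter's frozen_W for a QLinear while lora_A / lora_B stay in full precision.
model = nn.Sequential(nn.Linear(64, 64, bias=False))   # stand-in for the real Moshi LM
replace_all_linear_with_lora(model, rank=16, scaling=2.0)
replace_linear_with_qlinear(model)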
