Skip to content

Commit

Permalink
Tweak
Browse files Browse the repository at this point in the history
  • Loading branch information
cg123 committed Feb 25, 2025
1 parent ec8d95e commit 79c96ab
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions mergekit/scripts/extract_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def arguments(self) -> Dict[str, Any]:
def execute(self, task_vector: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.transpose:
task_vector = task_vector.T
out_dtype = task_vector.dtype
u, s, vh = torch.linalg.svd(
task_vector.to(dtype=torch.float32), full_matrices=False
)
Expand All @@ -244,9 +245,7 @@ def execute(self, task_vector: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor
weight_a = scale_a @ vh[:rank]
weight_b = u[:, :rank] @ scale_b

return weight_a.to(dtype=task_vector.dtype), weight_b.to(
dtype=task_vector.dtype
)
return weight_a.to(dtype=out_dtype), weight_b.to(dtype=out_dtype)

def group_label(self) -> Optional[str]:
return self.input_task.group_label()
Expand Down

0 comments on commit 79c96ab

Please sign in to comment.