diff --git a/mergekit/scripts/extract_lora.py b/mergekit/scripts/extract_lora.py index fe890c13..2b4f1524 100644 --- a/mergekit/scripts/extract_lora.py +++ b/mergekit/scripts/extract_lora.py @@ -227,6 +227,7 @@ def arguments(self) -> Dict[str, Any]: def execute(self, task_vector: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.transpose: task_vector = task_vector.T + out_dtype = task_vector.dtype u, s, vh = torch.linalg.svd( task_vector.to(dtype=torch.float32), full_matrices=False ) @@ -244,9 +245,7 @@ def execute(self, task_vector: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor weight_a = scale_a @ vh[:rank] weight_b = u[:, :rank] @ scale_b - return weight_a.to(dtype=task_vector.dtype), weight_b.to( - dtype=task_vector.dtype - ) + return weight_a.to(dtype=out_dtype), weight_b.to(dtype=out_dtype) def group_label(self) -> Optional[str]: return self.input_task.group_label()