LoRA extraction fixes #522

Open · wants to merge 6 commits into base: main
7 changes: 5 additions & 2 deletions mergekit/common.py
@@ -71,7 +71,10 @@ class ModelReference(BaseModel, frozen=True):
     override_architecture: Optional[str] = None
 
     def merged(
-        self, cache_dir: Optional[str] = None, trust_remote_code: bool = False
+        self,
+        cache_dir: Optional[str] = None,
+        trust_remote_code: bool = False,
+        lora_merge_dtype: Optional[str] = None,
     ) -> "ModelReference":
         """Merge the LoRA if applicable and return a reference to the result."""
         if not self.lora:
@@ -95,7 +98,7 @@ def merged(
             model = auto_cls.from_pretrained(
                 self.model.path,
                 revision=self.model.revision,
-                torch_dtype=torch.float16,
+                torch_dtype=dtype_from_name(lora_merge_dtype),
                 low_cpu_mem_usage=True,
                 trust_remote_code=trust_remote_code,
             )
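The hardcoded `torch.float16` is replaced by a user-selectable dtype. A minimal sketch of what the `dtype_from_name` helper plausibly does, assuming `None` falls back to the previous float16 default (the real `mergekit.common.dtype_from_name` may differ in the names it handles and in its `None` behavior):

```python
from typing import Optional

import torch

_DTYPES = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}


def dtype_from_name_sketch(name: Optional[str]) -> torch.dtype:
    # Assumption: None preserves the old hardcoded float16 behavior
    if name is None:
        return torch.float16
    try:
        return _DTYPES[name.removeprefix("torch.")]
    except KeyError:
        raise RuntimeError(f"Unimplemented dtype {name!r}")
```
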
6 changes: 5 additions & 1 deletion mergekit/io/tasks.py
@@ -22,6 +22,7 @@ class LoaderCache:
     hf_cache_dir: Optional[str] = None
     lazy_unpickle: bool = False
     trust_remote_code: bool = False
+    lora_merge_dtype: Optional[str] = None
 
     # singleton instance per thread
     _instance = threading.local()
@@ -34,7 +35,9 @@ def __new__(cls) -> "LoaderCache":
     def get(self, model: ModelReference) -> LazyTensorLoader:
         if model not in self.loaders:
             merged = model.merged(
-                cache_dir=self.lora_cache_dir, trust_remote_code=self.trust_remote_code
+                cache_dir=self.lora_cache_dir,
+                trust_remote_code=self.trust_remote_code,
+                lora_merge_dtype=self.lora_merge_dtype,
             )
             self.loaders[model] = merged.lazy_loader(
                 cache_dir=self.hf_cache_dir, lazy_unpickle=self.lazy_unpickle
@@ -50,6 +53,7 @@ def setup(self, options: MergeOptions):
         self.hf_cache_dir = options.transformers_cache
         self.lazy_unpickle = options.lazy_unpickle
         self.trust_remote_code = options.trust_remote_code
+        self.lora_merge_dtype = options.lora_merge_dtype
 
 
 shard_name_re = re.compile(r"model\-([0-9]+)-of-([0-9]+)")
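For context, a sketch of how the new field travels through the thread-local `LoaderCache` singleton; the `ModelReference.parse` call and the model strings are illustrative assumptions, not code from this PR:

```python
from mergekit.common import ModelReference
from mergekit.io.tasks import LoaderCache
from mergekit.options import MergeOptions

options = MergeOptions(lora_merge_dtype="bfloat16")
cache = LoaderCache()  # per-thread singleton (see _instance above)
cache.setup(options)   # now also copies lora_merge_dtype

# A "base+adapter" reference gets its LoRA merged in bfloat16 on first access
ref = ModelReference.parse("example-org/base-model+example-org/lora-adapter")
loader = cache.get(ref)
```
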
1 change: 0 additions & 1 deletion mergekit/merge_methods/multislerp.py
@@ -59,7 +59,6 @@ def multislerp(
 
     mean = (unit_tensors * weights.view(-1, 1)).sum(0)
     mean_norm = torch.norm(mean)
-    print(mean_norm)
     if mean_norm < eps:
         if tensors.shape[0] == 2:
             # fallback to linear interpolation
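The deleted `print(mean_norm)` was debug output sitting in front of the degenerate-case guard. The guard itself is meaningful: when the weighted unit vectors cancel, the mean cannot be renormalized back onto the sphere, hence the linear-interpolation fallback. A standalone illustration (not mergekit code):

```python
import torch

# Two antipodal unit vectors with equal weight: the spherical mean is undefined
tensors = torch.stack([torch.tensor([1.0, 0.0]), torch.tensor([-1.0, 0.0])])
weights = torch.tensor([0.5, 0.5])

mean = (tensors * weights.view(-1, 1)).sum(0)
print(torch.norm(mean))  # ~0, so multislerp must fall back to lerp
```
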
2 changes: 0 additions & 2 deletions mergekit/merge_methods/nuslerp.py
@@ -50,8 +50,6 @@ def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> Tensor:
         weights = [self.tensor_parameters[key]["weight"] for key in keys]
 
         if len(tensors) != 2:
-            print(keys)
-            print(self.base_model)
             raise RuntimeError(
                 "NuSlerp merge expects exactly two models (plus optional base model)"
             )
66 changes: 64 additions & 2 deletions mergekit/options.py
@@ -9,8 +9,8 @@
 import click
 import torch
 import transformers
-from click.core import Context, Parameter
-from pydantic import BaseModel
+from click.core import Context, HelpFormatter, Parameter
+from pydantic import BaseModel, model_validator
 
 from mergekit.common import parse_kmb

Expand All @@ -19,6 +19,7 @@ class MergeOptions(BaseModel, frozen=True):
allow_crimes: bool = False
transformers_cache: Optional[str] = None
lora_merge_cache: Optional[str] = None
lora_merge_dtype: Optional[str] = None
cuda: bool = False
low_cpu_memory: bool = False
out_shard_size: int = parse_kmb("5B")
@@ -34,6 +35,7 @@ class MergeOptions(BaseModel, frozen=True):
     read_to_gpu: bool = False
     multi_gpu: bool = False
     num_threads: Optional[int] = None
+    gpu_rich: bool = False
 
     def apply_global_options(self):
         logging.basicConfig(level=logging.INFO if self.verbose else logging.WARNING)
@@ -43,11 +45,21 @@ def apply_global_options(self):
             torch.set_num_threads(self.num_threads)
             torch.set_num_interop_threads(self.num_threads)
 
+    @model_validator(mode="before")
+    def handle_gpu_rich(cls, value):
+        if isinstance(value, dict) and value.get("gpu_rich"):
+            value["cuda"] = True
+            value["low_cpu_memory"] = True
+            value["read_to_gpu"] = True
+            value["multi_gpu"] = True
+        return value
+
 
 OPTION_HELP = {
     "allow_crimes": "Allow mixing architectures",
     "transformers_cache": "Override storage path for downloaded models",
     "lora_merge_cache": "Path to store merged LORA models",
+    "lora_merge_dtype": "Override dtype when applying LoRAs",
     "cuda": "Perform matrix arithmetic on GPU",
     "low_cpu_memory": "Store results and intermediate values on GPU. Useful if VRAM > RAM",
     "out_shard_size": "Number of parameters per output shard [default: 5B]",
@@ -63,6 +75,30 @@ def apply_global_options(self):
     "multi_gpu": "Use multi-gpu parallel graph execution engine",
     "num_threads": "Number of threads to use for parallel CPU operations",
     "verbose": "Enable verbose logging",
+    "gpu_rich": "Alias for --cuda --low-cpu-memory --read-to-gpu --multi-gpu",
 }
 
+OPTION_CATEGORIES = {
+    "lora_merge_cache": "Storage",
+    "transformers_cache": "Storage",
+    "out_shard_size": "Output Settings",
+    "copy_tokenizer": "Output Settings",
+    "clone_tensors": "Output Settings",
+    "write_model_card": "Output Settings",
+    "safe_serialization": "Output Settings",
+    "lazy_unpickle": "Performance",
+    "cuda": "Performance",
+    "low_cpu_memory": "Performance",
+    "read_to_gpu": "Performance",
+    "multi_gpu": "Performance",
+    "num_threads": "Performance",
+    "gpu_rich": "Performance",
+    "trust_remote_code": "Dangerous Options",
+    "allow_crimes": "Dangerous Options",
+    "random_seed": "Miscellaneous",
+    "verbose": "Miscellaneous",
+    "quiet": "Miscellaneous",
+    "lora_merge_dtype": "Miscellaneous",
+}
+
@@ -104,10 +140,14 @@ def wrapper(*args, **kwargs):
         else:
             arg_str = f"--{arg_name}"
         param_decls = [arg_str]
+        kwargs = {}
         if field_name == "verbose":
             param_decls = ["--verbose/--no-verbose", "-v"]
         if field_name == "num_threads":
             param_decls = ["--num-threads", "-j"]
+        if field_name == "gpu_rich":
+            param_decls = ["--gpu-rich"]
+            kwargs["is_flag"] = True
 
         help_str = OPTION_HELP.get(field_name, None)
         wrapper = click.option(
@@ -116,6 +156,28 @@
             default=info.default,
             help=help_str,
             show_default=field_name != "out_shard_size",
+            **kwargs,
         )(wrapper)
 
     return wrapper
+
+
+class PrettyPrintHelp(click.Command):
+    def format_options(self, ctx: Context, formatter: HelpFormatter) -> None:
+        categories = {None: []}
+        for param in ctx.command.params:
+            if param.name in OPTION_CATEGORIES:
+                category = OPTION_CATEGORIES[param.name]
+                if category not in categories:
+                    categories[category] = []
+                categories[category].append(param)
+            else:
+                categories[None].append(param)
+
+        for category, params in categories.items():
+            title = category or "Script Options"
+            opts = [p.get_help_record(ctx) for p in params]
+            opts = [opt for opt in opts if opt is not None]
+            if opts:
+                with formatter.section(title):
+                    formatter.write_dl(opts)
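Since `MergeOptions` is `frozen=True`, the new `--gpu-rich` alias is expanded in a `mode="before"` validator that rewrites the raw input dict instead of mutating the constructed model. A quick sanity check of the behavior this adds (a sketch, not part of the PR):

```python
from mergekit.options import MergeOptions

opts = MergeOptions(gpu_rich=True)
# handle_gpu_rich ran before validation and enabled the four underlying flags
assert opts.cuda and opts.low_cpu_memory and opts.read_to_gpu and opts.multi_gpu
```
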
1 change: 1 addition & 0 deletions mergekit/scripts/evolve.py
@@ -382,6 +382,7 @@ def _reshard_model(
     merged = model.merged(
         cache_dir=merge_cache,
         trust_remote_code=trust_remote_code,
+        lora_merge_dtype="bfloat16",
     )
     out_path = os.path.join(
         storage_path,
61 changes: 43 additions & 18 deletions mergekit/scripts/extract_lora.py
@@ -22,12 +22,12 @@
 from mergekit.io.tasks import FinalizeModel, LoadTensor, SaveTensor, TensorWriterTask
 from mergekit.io.tensor_writer import TensorWriter
 from mergekit.multigpu_executor import MultiGPUExecutor
-from mergekit.options import MergeOptions, add_merge_options
+from mergekit.options import MergeOptions, PrettyPrintHelp, add_merge_options
 
 logger = logging.getLogger("extract_lora")
 
 
-@click.command("mergekit-extract-lora")
+@click.command("mergekit-extract-lora", cls=PrettyPrintHelp)
 @click.option(
     "--model",
     required=True,
@@ -59,7 +59,7 @@
     "--embed-lora/--no-embed-lora",
     is_flag=True,
     default=False,
-    help="Extract LoRA weights for embeddings",
+    help="Extract LoRA weights for embeddings (vs. in modules_to_save)",
 )
 @click.option(
     "--save-module",
@@ -92,6 +92,12 @@
     help="Threshold for singular values to discard",
     show_default=True,
 )
+@click.option(
+    "--skip-undecomposable",
+    is_flag=True,
+    help="Skip saving undecomposable modules",
+    default=False,
+)
 @add_merge_options
 def main(
     base_model: str,
@@ -104,6 +110,7 @@ def main(
     exclude_regexes: List[str],
     include_regexes: List[str],
     sv_epsilon: float,
+    skip_undecomposable: bool,
     merge_options: MergeOptions,
 ):
     merge_options.apply_global_options()
@@ -117,10 +124,12 @@ def main(
         base_model_ref=base_model_ref.merged(
             cache_dir=merge_options.lora_merge_cache,
             trust_remote_code=merge_options.trust_remote_code,
+            lora_merge_dtype=merge_options.lora_merge_dtype,
         ),
         model_ref=model_ref.merged(
             cache_dir=merge_options.lora_merge_cache,
             trust_remote_code=merge_options.trust_remote_code,
+            lora_merge_dtype=merge_options.lora_merge_dtype,
         ),
         modules_to_save=modules_to_save,
         out_path=out_path,
@@ -220,7 +229,10 @@ def arguments(self) -> Dict[str, Any]:
     def execute(self, task_vector: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.transpose:
             task_vector = task_vector.T
-        u, s, vh = torch.linalg.svd(task_vector, full_matrices=False)
+        out_dtype = task_vector.dtype
+        u, s, vh = torch.linalg.svd(
+            task_vector.to(dtype=torch.float32), full_matrices=False
+        )
         rank = min(self.max_rank, s.shape[0])
         if self.sv_epsilon > 0:
             rank = min((s > self.sv_epsilon).sum().item(), rank)
@@ -235,7 +247,7 @@ def execute(self, task_vector: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor
         weight_a = scale_a @ vh[:rank]
         weight_b = u[:, :rank] @ scale_b
 
-        return weight_a, weight_b
+        return weight_a.to(dtype=out_dtype), weight_b.to(dtype=out_dtype)
 
     def group_label(self) -> Optional[str]:
         return self.input_task.group_label()
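The float32 round-trip matters because `torch.linalg.svd` has no half-precision kernels, so fp16/bf16 task vectors would error out without the upcast; casting the factors back keeps the saved LoRA in the model's dtype. A self-contained sketch of the same pattern, mirroring the `scale_a`/`scale_b` split of the singular values above (dimensions and rank arbitrary):

```python
import torch

task_vector = torch.randn(64, 32, dtype=torch.bfloat16)  # e.g. W_ft - W_base
out_dtype = task_vector.dtype

# Upcast: torch.linalg.svd rejects half-precision inputs
u, s, vh = torch.linalg.svd(task_vector.to(torch.float32), full_matrices=False)

rank = 8
sqrt_s = torch.diag(torch.sqrt(s[:rank]))
weight_a = (sqrt_s @ vh[:rank]).to(out_dtype)    # LoRA A: (rank, in_features)
weight_b = (u[:, :rank] @ sqrt_s).to(out_dtype)  # LoRA B: (out_features, rank)
# weight_b @ weight_a is the best rank-8 approximation of the task vector
```
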
@@ -282,13 +294,14 @@ def execute(
                 f"No SVD decomposition for required weight {self.weight_info.name}"
             )
             return
-        lora_type = "lora_embedding" if self.weight_info.is_embed else "lora"
+        lora_type = "lora_embedding" if self.decomposition_task.transpose else "lora"
+        lora_suffix = ".weight" if not self.decomposition_task.transpose else ""
         base_name = self.weight_info.name.removesuffix(".weight")
         writer.save_tensor(
-            f"base_model.model.{base_name}.{lora_type}_A.weight", weight_a
+            f"base_model.model.{base_name}.{lora_type}_A{lora_suffix}", weight_a
         )
         writer.save_tensor(
-            f"base_model.model.{base_name}.{lora_type}_B.weight", weight_b
+            f"base_model.model.{base_name}.{lora_type}_B{lora_suffix}", weight_b
         )
 
     def priority(self) -> int:
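Switching the key prefix on `transpose` and dropping the `.weight` suffix for embeddings lines the output up with PEFT's state-dict layout, where embedding LoRA factors are stored as bare parameters rather than submodule weights. Illustrative key names (module paths hypothetical):

```python
expected_keys = [
    # nn.Linear (transpose=False): lora_A / lora_B submodules with .weight
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight",
    # nn.Embedding (transpose=True): bare nn.Parameter, no .weight suffix
    "base_model.model.model.embed_tokens.lora_embedding_A",
    "base_model.model.model.embed_tokens.lora_embedding_B",
]
```
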
@@ -327,6 +340,7 @@ def plan_extraction(
     exclude_regexes: Optional[List[str]] = None,
     include_regexes: Optional[List[str]] = None,
     sv_epsilon: float = 0,
+    skip_undecomposable: bool = False,
 ) -> PlanResults:
     targets = []
     writer_task = TensorWriterTask(
@@ -357,15 +371,12 @@ def plan_extraction(
 
     ft_vocab = embed_in.weight.shape[0]
     base_vocab = dummy_base.get_input_embeddings().weight.shape[0]
-    if ft_vocab != base_vocab:
+    if ft_vocab != base_vocab and embed_lora:
         logger.warning(
             f"Vocabulary size mismatch: fine-tuned model has {ft_vocab} tokens, base model has {base_vocab} tokens"
         )
         logger.warning("Enforcing embeddings in modules_to_save, embed_lora=False")
         embed_lora = False
-        force_embed_save = True
-    else:
-        force_embed_save = False
 
     warned_modules = set()
 
@@ -394,9 +405,17 @@ def _should_extract(name: str) -> bool:
         else:
             continue
 
-        if (force_embed_save and (module == embed_in or module == embed_out)) or (
-            not embed_lora and isinstance(module, nn.Embedding)
+        if (
+            (not embed_lora)
+            and (
+                module == embed_in
+                or module == embed_out
+                or isinstance(module, nn.Embedding)
+            )
+            and not any(re.search(r, name) for r in exclude_regexes or [])
         ):
+            # If embeddings are not explicitly excluded but embed_lora is False,
+            # save them at full rank instead of decomposing
             key = name.split(".")[-1]
             if key not in modules_to_save:
                 logger.warning(f"Adding {key} to modules_to_save")
@@ -423,12 +442,18 @@ def _should_extract(name: str) -> bool:
             )
         else:
             key = name.split(".")[-1]
+            message = (
+                f"{key} has unsupported module type {type(module).__name__} - "
+                + ("skipping" if skip_undecomposable else "saving at full rank")
+            )
+            if not skip_undecomposable:
                 # into modules_to_save it goes
                 targets.extend(
                     plan_module_to_save(model_ref, writer_task, wi, bias_wi)
                 )
+            if key not in warned_modules:
+                logger.warning(message)
+                warned_modules.add(key)
-            generic_name = re.sub(r"\.(\d+)\.", ".N.", name)
-            logger.warning(
-                f"{generic_name} has unsupported module type {type(module).__name__} - skipping"
-            )
 
     save_tasks = [t for t in targets if isinstance(t, (SaveTensor, LoRAModuleSaveTask))]
     finalize = FinalizeModel(tensor_save_tasks=save_tasks, writer_task=writer_task)
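Taken together, the planner's per-module behavior after this change can be restated as follows (illustrative pseudologic, not the actual `plan_extraction` control flow):

```python
import torch.nn as nn


def handling_for(module: nn.Module, embed_lora: bool, skip_undecomposable: bool) -> str:
    if isinstance(module, nn.Linear):
        return "decompose into lora_A / lora_B"
    if isinstance(module, nn.Embedding):
        if embed_lora:
            return "decompose (transposed) into lora_embedding_A / lora_embedding_B"
        return "save at full rank via modules_to_save"
    # anything else has no SVD decomposition
    return "skip" if skip_undecomposable else "save at full rank via modules_to_save"
```
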
4 changes: 2 additions & 2 deletions mergekit/scripts/layershuffle.py
@@ -15,10 +15,10 @@
     OutputSliceDefinition,
 )
 from mergekit.merge import run_merge
-from mergekit.options import MergeOptions, add_merge_options
+from mergekit.options import MergeOptions, PrettyPrintHelp, add_merge_options
 
 
-@click.command("mergekit-layershuffle")
+@click.command("mergekit-layershuffle", cls=PrettyPrintHelp)
 @click.argument("out_path", type=str)
 @click.option("--model", "-m", multiple=True, type=str, help="Add a model to the merge")
 @click.option(
4 changes: 2 additions & 2 deletions mergekit/scripts/legacy.py
@@ -8,10 +8,10 @@
 
 from mergekit.config import InputModelDefinition, MergeConfiguration
 from mergekit.merge import run_merge
-from mergekit.options import MergeOptions, add_merge_options
+from mergekit.options import MergeOptions, PrettyPrintHelp, add_merge_options
 
 
-@click.command("mergekit-legacy")
+@click.command("mergekit-legacy", cls=PrettyPrintHelp)
 @click.argument("out_path", type=str)
 @click.option(
     "--merge", "merge", type=str, multiple=True, help="Add a model to the merge"
Expand Down