Add PEFT add_weighted_adapter() Function for Merging Multiple Adapters #6310

Open · wants to merge 6 commits into base: main
Binary file modified assets/wechat.jpg
Binary file modified assets/wechat_npu.jpg
4 changes: 2 additions & 2 deletions examples/README.md
@@ -132,8 +132,8 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
#### Supervised Fine-Tuning on Multiple Nodes

```bash
-FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
```

#### Multimodal Supervised Fine-Tuning
4 changes: 2 additions & 2 deletions examples/README_zh.md
@@ -132,8 +132,8 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
#### 在多机上进行指令监督微调

```bash
-FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
```

#### 多模态指令监督微调
8 changes: 8 additions & 0 deletions src/llamafactory/hparams/model_args.py
@@ -169,6 +169,14 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
            )
        },
    )
    combination_type: Optional[str] = field(
        default=None,
        metadata={"help": "Method used to merge multiple LoRA adapters. One of ['cat', 'svd', 'linear']."},
    )
    combination_weights: Optional[list[float]] = field(
        default=None,
        metadata={"help": "List of weights, one for each adapter to merge."},
    )
    adapter_folder: Optional[str] = field(
        default=None,
        metadata={"help": "The folder containing the adapter weights to load."},
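For reference, a minimal sketch of how the two new fields are expected to parse from the command line; the `MergeArguments` dataclass below is a stand-in for `ModelArguments` and is not part of the PR:

```python
# Stand-in dataclass: it only mirrors the two fields added above; everything else is hypothetical.
from dataclasses import dataclass, field
from typing import List, Optional

from transformers import HfArgumentParser


@dataclass
class MergeArguments:
    combination_type: Optional[str] = field(default=None)
    combination_weights: Optional[List[float]] = field(default=None)


parser = HfArgumentParser(MergeArguments)
(args,) = parser.parse_args_into_dataclasses(
    args=["--combination_type", "linear", "--combination_weights", "0.6", "0.4"]
)
print(args.combination_type)     # linear
print(args.combination_weights)  # [0.6, 0.4]
```

When the arguments come from a YAML config instead, the same keys would carry a string and a list of floats.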
56 changes: 53 additions & 3 deletions src/llamafactory/model/adapter.py
@@ -16,7 +16,7 @@
from typing import TYPE_CHECKING

import torch
-from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
+from peft import LoraConfig, LoraModel, PeftModel, PeftModelForCausalLM, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled

@@ -177,10 +177,60 @@ def _setup_lora_tuning(
"token": model_args.hf_hub_token,
}

-for adapter in adapter_to_merge:
-    model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
if len(adapter_to_merge) > 1 and model_args.combination_type is not None:
    if model_args.combination_weights is None:
        raise ValueError(
            f"`combination_weights` must be provided when merging LoRA adapters with '{model_args.combination_type}'."
        )
    elif len(model_args.combination_weights) != len(adapter_to_merge):
        raise ValueError("The number of combination_weights must match the number of adapters.")

    weights = model_args.combination_weights
    adapter_names = []
    for idx, adapter in enumerate(adapter_to_merge):
        adapter_name = f"ad_{idx}"
        if idx == 0:
            model = PeftModelForCausalLM.from_pretrained(model, adapter, adapter_name=adapter_name, **init_kwargs)
        else:
            model.load_adapter(adapter, adapter_name, **init_kwargs)
        adapter_names.append(adapter_name)

    # merge_and_unload() rescales each LoRA delta by lora_alpha / r when folding it into the base weights,
    # so the user-provided combination weights are adjusted here to cancel out that scaling.
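    # For example, with lora_alpha=16 and r=8 the scaling factor is 2.0: a requested weight of 0.5 is passed
    # to add_weighted_adapter() as 0.5 / 2.0 = 0.25 for 'cat'/'svd' and as 0.5**2 / 2.0 = 0.125 for 'linear'.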
    adapter_scaling = []
    for adapter_name in adapter_names:
        adapter_config = model.peft_config[adapter_name]
        adapter_scaling.append(adapter_config.lora_alpha / adapter_config.r)

    weighted_adapter_name = "merged_weighted_ad"
    if model_args.combination_type in ["cat", "svd"]:
        weights = [wi / ai for ai, wi in zip(adapter_scaling, weights)]
    elif model_args.combination_type == "linear":
        weights = [wi**2 / ai for ai, wi in zip(adapter_scaling, weights)]
    else:
        raise ValueError(f"Unsupported combination_type: {model_args.combination_type}.")

    model.add_weighted_adapter(
        adapters=adapter_names,
        weights=weights,
        adapter_name=weighted_adapter_name,
        combination_type=model_args.combination_type,
    )
    model.set_adapter(weighted_adapter_name)
    for name in adapter_names:
        model.delete_adapter(name)

    model = model.merge_and_unload()

else:
    for adapter in adapter_to_merge:
        model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
        model = model.merge_and_unload()

if len(adapter_to_merge) > 0:
    logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")

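For context, a minimal standalone sketch of the PEFT `add_weighted_adapter()` flow that `_setup_lora_tuning` now wires up; the model path, adapter paths, and weights below are placeholders, and the lora_alpha / r compensation from the PR is omitted:

```python
# Standalone sketch (paths and weights are placeholders, not from the PR).
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("path/to/base_model")

# Load the first adapter under an explicit name, then attach the second one.
model = PeftModel.from_pretrained(base, "path/to/adapter_a", adapter_name="ad_0")
model.load_adapter("path/to/adapter_b", adapter_name="ad_1")

# Combine both adapters into a single new adapter.
# Note: 'linear' requires the adapters to share the same rank; 'cat' and 'svd' also handle different ranks.
model.add_weighted_adapter(
    adapters=["ad_0", "ad_1"],
    weights=[0.6, 0.4],
    adapter_name="merged_weighted_ad",
    combination_type="linear",
)

# Activate the merged adapter, drop the originals, and fold it into the base weights.
model.set_adapter("merged_weighted_ad")
model.delete_adapter("ad_0")
model.delete_adapter("ad_1")
model = model.merge_and_unload()
model.save_pretrained("path/to/merged_model")
```

Recent PEFT releases also accept additional combination types (e.g. `ties`, `dare_linear`); this PR exposes only `cat`, `svd`, and `linear`.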