train_react.py
# import unsloth
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser, set_seed
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel

from utils import create_and_prepare_model, create_datasets
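# Note: `FastLanguageModel` is not referenced in this file; the unsloth import
# is presumably kept for its import-time patches to transformers (cf. the
# commented-out `import unsloth` at the very top, meant to run before other
# imports).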

# Define and parse arguments.
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    chat_template_format: Optional[str] = field(
        default="none",
        metadata={
            "help": "chatml|zephyr|none. Pass `none` if the dataset is already formatted with the chat template."
        },
    )
    lora_alpha: Optional[int] = field(default=16)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_r: Optional[int] = field(default=64)
    lora_target_modules: Optional[str] = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={"help": "Comma-separated list of target modules to apply LoRA layers to."},
    )
    use_nested_quant: Optional[bool] = field(
        default=False,
        metadata={"help": "Activate nested quantization for 4-bit base models."},
    )
    bnb_4bit_compute_dtype: Optional[str] = field(
        default="float16",
        metadata={"help": "Compute dtype for 4-bit base models."},
    )
    bnb_4bit_quant_storage_dtype: Optional[str] = field(
        default="uint8",
        metadata={"help": "Quantization storage dtype for 4-bit base models."},
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization type: fp4 or nf4."},
    )
    use_flash_attn: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash Attention for training."},
    )
    use_peft_lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables PEFT LoRA for training."},
    )
    use_8bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading the model in 8-bit."},
    )
    use_4bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading the model in 4-bit."},
    )
    use_reentrant: Optional[bool] = field(
        default=False,
        metadata={"help": "Gradient checkpointing parameter. Refer to the related docs."},
    )
    use_unsloth: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Unsloth for training."},
    )


@dataclass
class DataTrainingArguments:
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "The preference dataset to use."},
    )
    append_concat_token: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, appends `eos_token_id` at the end of each sample being packed."},
    )
    add_special_tokens: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, the tokenizer adds special tokens to each sample being packed."},
    )
    splits: Optional[str] = field(
        default="train,test",
        metadata={"help": "Comma-separated list of the splits to use from the dataset."},
    )
    data_format: Optional[str] = field(
        default="reason",
        metadata={"help": "Distinguishes whether the data is used to train the reasoner or the baseline only."},
    )
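

# A minimal example config for the JSON entry point below (hypothetical paths
# and values; fields such as `output_dir` and `seed` come from `SFTConfig` /
# `TrainingArguments`):
#
#   {
#     "model_name_or_path": "meta-llama/Llama-2-7b-hf",
#     "dataset_name": "timdettmers/openassistant-guanaco",
#     "use_peft_lora": true,
#     "use_4bit_quantization": true,
#     "output_dir": "outputs/train_react",
#     "seed": 42
#   }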
def main(model_args, data_args, training_args):
    # print(f"model_args: {model_args}")
    # print(f"data_args: {data_args}")
    # print(f"training_args: {training_args}")

    # Set seed for reproducibility
    set_seed(training_args.seed)

    # model
    model, peft_config, tokenizer = create_and_prepare_model(model_args, data_args, training_args)

    # Check whether lm_head and embed_tokens share weights.
    if model.get_output_embeddings().weight is model.get_input_embeddings().weight:
        print("lm_head and embed_tokens share weights")
    else:
        print("lm_head and embed_tokens do not share weights")

    for name, param in model.named_parameters():
        # print(name)
        if "embed_tokens" in name:
            print(name, param.requires_grad)  # should be True
    print(f"Model embedding size: {model.get_input_embeddings().weight.size(0)}")
    print(f"Tokenizer length: {len(tokenizer)}")
    model.config.use_cache = not training_args.gradient_checkpointing
    training_args.gradient_checkpointing = training_args.gradient_checkpointing and not model_args.use_unsloth
    if training_args.gradient_checkpointing:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": model_args.use_reentrant}

    training_args.dataset_kwargs = {
        "append_concat_token": data_args.append_concat_token,
        "add_special_tokens": data_args.add_special_tokens,
    }
    # datasets
    train_dataset, eval_dataset = create_datasets(
        tokenizer,
        data_args,
        training_args,
        apply_chat_template=model_args.chat_template_format != "none",
    )

    # trainer
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=peft_config,
    )
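    # Note: SFTTrainer applies peft_config to the model itself when one is
    # given; create_and_prepare_model presumably returns peft_config=None on
    # paths where the model is already wrapped (e.g. Unsloth).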
    trainer.accelerator.print(f"{trainer.model}")
    if hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()

    # train
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    trainer.train(resume_from_checkpoint=checkpoint)

    # Save the final weights to training_args.output_dir (only the LoRA
    # adapter when PEFT is enabled).
    trainer.save_model()


if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SFTConfig))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    main(model_args, data_args, training_args)
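
# Example launches (hypothetical paths/values; any SFTConfig flag also works):
#   python train_react.py config.json
#   python train_react.py \
#       --model_name_or_path meta-llama/Llama-2-7b-hf \
#       --dataset_name timdettmers/openassistant-guanaco \
#       --use_peft_lora True --use_4bit_quantization True \
#       --output_dir outputs/train_react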