
Commit a364fee

[ILUVATAR_GPU] Support dpo
1 parent 52712c8 commit a364fee

File tree

2 files changed: +107 -0

ernie/loss/dpo.py: 2 additions, 0 deletions

@@ -312,6 +312,8 @@ def forward(
             policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, score_deltas
         )
         loss = dpo_loss + sft_loss
+        if "iluvatar" in paddle.get_device():
+            paddle.device.empty_cache()
         if self.use_infohub:
             infohub.policy_chosen_logps.append(policy_chosen_logps.detach())
             infohub.policy_rejected_logps.append(policy_rejected_logps.detach())
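The new guard calls paddle.device.empty_cache() only when paddle.get_device() reports an Iluvatar device, so CUDA and CPU runs are unaffected. A minimal sketch of the same check factored into a reusable helper; the helper name and the per-step call site are illustrative, not part of the commit:

```python
import paddle


def empty_device_cache_if_iluvatar():
    """Release cached allocator blocks, but only on Iluvatar GPUs.

    paddle.get_device() returns device strings such as "iluvatar_gpu:0" when
    the Iluvatar plugin is active, so a substring check keeps this a no-op on
    CUDA/CPU runs.
    """
    if "iluvatar" in paddle.get_device():
        paddle.device.empty_cache()


# Illustrative call site, mirroring the diff: once per step, after the loss
# has been assembled.
#   loss = dpo_loss + sft_loss
#   empty_device_cache_if_iluvatar()
```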
New file: 105 additions, 0 deletions

@@ -0,0 +1,105 @@
### data
train_dataset_type: "erniekit"
eval_dataset_type: "erniekit"
train_dataset_path: "./examples/data/dpo-train.jsonl"
train_dataset_prob: "1.0"
eval_dataset_path: "./examples/data/dpo-eval.jsonl"
eval_dataset_prob: "1.0"
max_seq_len: 8192
num_samples_each_epoch: 6000000

### model
model_name_or_path: /home/tianyu.zhou/ERNIE-4.5-21B-A3B-Paddle
moe_group: mp
fine_tuning: LoRA
lora_rank: 32
lora_alpha: 128
lora_plus_scale: 12
rslora: True
fuse_rope: True
fuse_linear: True

### finetuning
# base
stage: DPO
seed: 42
do_train: True
do_eval: True
distributed_dataloader: True
dataloader_num_workers: 4
batch_size: 1
num_train_epochs: 1
max_steps: 800
max_evaluate_steps: 10000
eval_steps: 20000
evaluation_strategy: epoch
save_steps: 100
save_total_limit: 5
save_strategy: epoch
logging_steps: 1
release_grads: True
gradient_accumulation_steps: 8
logging_dir: ./vdl_log
output_dir: ./output
disable_tqdm: True

# train
warmup_steps: 50
learning_rate: 5.0e-7
lr_scheduler_type: cosine
min_lr: 5.0e-7
layerwise_lr_decay_bound: 1.0
attention_probs_dropout_prob: 0.1
dropout_warmup_steps: 100

# loss
offset_alpha: 0.0
scale_loss: 8192

# optimizer
weight_decay: 0.1
adam_epsilon: 1.0e-8
adam_beta1: 0.9
adam_beta2: 0.95
offload_optim: True

# performance
use_sp_callback: True
tensor_parallel_degree: 4
tensor_parallel_config: "sync_param sync_grad sync_moment"
pipeline_parallel_degree: 1
sharding_parallel_degree: 1
sharding: stage1
sequence_parallel: True
pipeline_parallel_config: disable_partial_send_recv enable_clear_every_step_cache disable_batch_p2p_comm
recompute: True
recompute_use_reentrant: True
compute_type: bf16
fp16_opt_level: O2
amp_master_grad: True
amp_custom_white_list:
  - "lookup_table"
  - "lookup_table_v2"
  - "flash_attn"
  - "matmul"
  - "matmul_v2"
  - "fused_gemm_epilogue"
amp_custom_black_list:
  - "reduce_sum"
  - "softmax_with_cross_entropy"
  - "c_softmax_with_cross_entropy"
  - "elementwise_div"
  - "sin"
  - "cos"
unified_checkpoint: True
unified_checkpoint_config: async_save

use_flash_attention: True
use_sparse_head_and_loss_fn: False
use_attn_mask_startend_row_indices: False
use_sparse_flash_attn: False
moe_multimodal_dispatch_use_allgather: "v2-alltoall"
device: iluvatar_gpu
fuse_rms_norm: False
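The config above pairs a LoRA DPO run (stage: DPO, fine_tuning: LoRA) with 4-way tensor parallelism, bf16 autocast, and the Iluvatar backend selected via device: iluvatar_gpu. A small, hypothetical sketch of loading such a file and sanity-checking a few of the keys shown above; the file name, PyYAML-based loader, and checks are illustrative, not the repository's own launcher:

```python
import yaml


def load_dpo_config(path: str) -> dict:
    """Parse the training YAML into a plain dict (PyYAML)."""
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


if __name__ == "__main__":
    cfg = load_dpo_config("dpo_iluvatar_lora.yaml")  # illustrative file name
    # Quick checks against the settings in the config above.
    assert cfg["stage"] == "DPO"
    assert cfg["device"] == "iluvatar_gpu"
    tp = cfg["tensor_parallel_degree"]
    pp = cfg["pipeline_parallel_degree"]
    sd = cfg["sharding_parallel_degree"]
    print(f"LoRA rank {cfg['lora_rank']}, devices per replica = {tp * pp * sd}")
```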
