Skip to content

Commit 6357e91

Browse files
authored
device = finetuning_args.device (#1379)
* text_sft_reader默认设置GPU设备,XPU需要传入device参数。
1 parent 2590403 commit 6357e91

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

erniekit/train/vl_sft/workflow.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,7 @@ def compute_metrics(p):
700700
"use_train_part_sharding": finetuning_args.text_use_train_part_sharding,
701701
"rope_3d": model_args.rope_3d,
702702
"chat_template": preprocess_args.chat_template,
703+
"device": finetuning_args.device,
703704
}
704705

705706
text_sft_train_reader = create_pyreader(config_dataset_text)

0 commit comments

Comments
 (0)