Skip to content

Commit 8cae9d7

Browse files
committed
adapt fc
1 parent 5eb4e29 commit 8cae9d7

File tree

8 files changed

+803
-224
lines changed

8 files changed

+803
-224
lines changed

examples/pre-training/ernie/pretrain.py

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -296,16 +296,16 @@ def formatv(v):
296296
and not args.overwrite_output_dir
297297
):
298298
last_checkpoint = get_last_checkpoint(args.output_dir)
299-
if last_checkpoint is None and len(os.listdir(args.output_dir)) > 0:
300-
raise ValueError(
301-
f"Output directory ({args.output_dir}) already exists and is not empty. "
302-
"Use --overwrite_output_dir to overcome."
303-
)
304-
elif last_checkpoint is not None and args.resume_from_checkpoint is None:
305-
logger.info(
306-
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
307-
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
308-
)
299+
# if last_checkpoint is None and len(os.listdir(args.output_dir)) > 0:
300+
# raise ValueError(
301+
# f"Output directory ({args.output_dir}) already exists and is not empty. "
302+
# "Use --overwrite_output_dir to overcome."
303+
# )
304+
# elif last_checkpoint is not None and args.resume_from_checkpoint is None:
305+
# logger.info(
306+
# f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
307+
# "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
308+
# )
309309

310310
def compute_metrics(p):
311311
preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
@@ -439,6 +439,7 @@ def sname_to_tname(pp_model):
439439
cfg.token_balance_seqlen = args.max_seq_length * args.per_device_train_batch_size
440440
cfg.fp16_opt_level = args.fp16_opt_level
441441
cfg.moe_group = args.moe_group
442+
cfg.moe_group_name = args.moe_group
442443
cfg.dtype = dtype
443444
cfg.use_fp8 = args.use_fp8
444445
cfg.enable_mtp_magic_send = args.enable_mtp_magic_send
@@ -502,7 +503,7 @@ def sname_to_tname(pp_model):
502503
logger.info(f"using model type:{type(model)}")
503504
paddle.set_default_dtype("float32")
504505

505-
logger.info(f"using model={type(model)}, cfg={cfg}")
506+
# logger.info(f"using model={type(model)}, cfg={cfg}")
506507

507508
train_dataset, eval_dataset, test_dataset, data_collator = (
508509
create_pretrained_dataset(args)

examples/pre-training/ernie/src/trainers/pretraining_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1260,7 +1260,7 @@ def _maybe_log_save_evaluate(
12601260
)
12611261
logs["learning_rate"] = float(self._get_learning_rate())
12621262
logs["global_step"] = int(self.state.global_step)
1263-
1263+
logs["loss_md5"] = paddle.to_tensor(logs["loss"])._md5sum()
12641264
divisor = 2**30
12651265

12661266
current_device = framework._current_expected_place_()

examples/pre-training/models/ernie/configuration.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,7 @@ def __init__(
8080
initializer_range=0.02,
8181
rms_norm_eps=1e-6,
8282
use_cache=False,
83-
use_flash_attn=True,
83+
use_flash_attn=False,
8484
use_mem_eff_attn=False,
8585
use_flash_attn_with_mask=False,
8686
use_recompute=False,
@@ -149,6 +149,7 @@ def __init__(
149149
global_aux_loss=False,
150150
moe_dropout_prob=0.0,
151151
moe_group="world",
152+
moe_group_name="world",
152153
num_experts_per_tok: int = 8,
153154
moe_intermediate_size: Union[int, list] = 0,
154155
moe_num_shared_experts: int = 0,
@@ -356,6 +357,7 @@ def update_nested_dict(default_dict, update_dict):
356357
self.moe_layer_interval = moe_layer_interval
357358
self.moe_dropout_prob = moe_dropout_prob
358359
self.moe_group = moe_group
360+
self.moe_group_name = moe_group_name
359361
self.num_experts_per_tok = num_experts_per_tok
360362
self.moe_num_shared_experts = moe_num_shared_experts
361363
self.moe_num_dense_experts = moe_num_dense_experts
@@ -395,7 +397,6 @@ def update_nested_dict(default_dict, update_dict):
395397

396398
self.use_linear_residual_norm_recompute = use_linear_residual_norm_recompute
397399
self.use_rms_qkv_recompute = use_rms_qkv_recompute
398-
399400
assert aux_loss_type in ["", "default", "seq_aux_loss", "switch_aux_loss"]
400401
self.aux_loss_type = aux_loss_type
401402

0 commit comments

Comments
 (0)