Move save_checkpoint() before evaluate_model() #446

@jolonf

I've just started using nanochat on an iMac M4 with 32GB and I'm hitting the same problem as issue #427, where the evals take too much memory. When I kill the evals, the checkpoint is never saved, even though training has finished.

If I move the save_checkpoint() if block above the evaluate_model() if block in base_train.py, the checkpoint is saved before the evals run.

I'm not familiar enough with the code base to know whether I should be running it with different parameters or what other effects this move might have.

The save_checkpoint() block:

# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
    save_checkpoint(
        checkpoint_dir,
        step,
        orig_model.state_dict(), # model parameters
        [opt.state_dict() for opt in optimizers], # optimizer states
        { # metadata saved as json
            "step": step,
            "val_bpb": val_bpb, # loss at last step
            "model_config": model_config_kwargs,
            "user_config": user_config, # inputs to the training script
            "device_batch_size": args.device_batch_size,
            "max_seq_len": args.max_seq_len,
            "dataloader_state_dict": dataloader_state_dict,
            "loop_state": { # all loop state (other than step) so that we can resume training
                "min_val_bpb": min_val_bpb,
                "smooth_train_loss": smooth_train_loss,
                "total_training_time": total_training_time,
            },
        },
        rank=ddp_rank,
    )

This block needs to move above the evaluate_model() if block:

# once in a while: estimate the CORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
results = {}
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
    model.eval()
    with autocast_ctx:
        results = evaluate_model(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
    print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
    wandb_run.log({
        "step": step,
        "total_training_flops": flops_so_far,
        "core_metric": results["core_metric"],
        "centered_results": results["centered_results"],
    })
    model.train()
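
To make the ordering problem concrete, here is a minimal, self-contained sketch. It is not the actual base_train.py code. `should_save` mirrors the save condition quoted above, while `fake_save_checkpoint`, `fake_evaluate_model`, and `run_last_step` are hypothetical stand-ins I made up for illustration. The eval failure is simulated with an exception, whereas a real kill would simply end the process at that point. Note that the metadata dict in the save block above does not reference `results`, so as far as the quoted code goes, the save does not seem to depend on the CORE eval having run.

```python
import json
import os
import tempfile


def should_save(step, last_step, resume_from_step, save_every):
    """Mirror of the save condition quoted in the issue."""
    return last_step or (
        step > 0
        and step != resume_from_step
        and save_every > 0
        and step % save_every == 0
    )


def fake_save_checkpoint(checkpoint_dir, step, meta):
    # Hypothetical stand-in for save_checkpoint(): writes only a metadata JSON file.
    path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
    with open(path, "w") as f:
        json.dump(meta, f)
    return path


def fake_evaluate_model():
    # Hypothetical stand-in for evaluate_model(): simulates the eval dying.
    raise MemoryError("simulated eval failure")


def run_last_step(checkpoint_dir, step, save_first):
    """Run the save and eval blocks for the final step in the chosen order.

    Returns the sorted list of files left in checkpoint_dir afterwards.
    """
    def save():
        if should_save(step, last_step=True, resume_from_step=-1, save_every=0):
            fake_save_checkpoint(checkpoint_dir, step, {"step": step})

    try:
        if save_first:
            save()
            fake_evaluate_model()
        else:
            fake_evaluate_model()
            save()  # never reached when the eval dies
    except MemoryError:
        pass  # a real kill would terminate the process here
    return sorted(os.listdir(checkpoint_dir))


with tempfile.TemporaryDirectory() as d:
    print("eval first:", run_last_step(d, 100, save_first=False))
with tempfile.TemporaryDirectory() as d:
    print("save first:", run_last_step(d, 100, save_first=True))
```

With the eval first, nothing is written. With the save first, the metadata file survives the simulated eval failure, which matches what I see after moving the block.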
