Description
I've just started using nanochat on an iMac M4 (32GB) and I'm hitting the same problem as issue #427, where the evals take too much memory. When I kill the evals to work around that, the checkpoint is never saved, even though training itself has finished.
If I move the save_checkpoint() if block above the evaluate_model() if block in base_train.py, the checkpoint is saved before the evals run.
I'm not familiar enough with the codebase to know whether I should be running with different parameters, or what other effects this reordering might have.
The save_checkpoint() block:
nanochat/scripts/base_train.py
Lines 338 to 360 in d58fcd9
```python
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
if last_step or (step > 0 and step != args.resume_from_step and args.save_every > 0 and step % args.save_every == 0):
    save_checkpoint(
        checkpoint_dir,
        step,
        orig_model.state_dict(), # model parameters
        [opt.state_dict() for opt in optimizers], # optimizer states
        { # metadata saved as json
            "step": step,
            "val_bpb": val_bpb, # loss at last step
            "model_config": model_config_kwargs,
            "user_config": user_config, # inputs to the training script
            "device_batch_size": args.device_batch_size,
            "max_seq_len": args.max_seq_len,
            "dataloader_state_dict": dataloader_state_dict,
            "loop_state": { # all loop state (other than step) so that we can resume training
                "min_val_bpb": min_val_bpb,
                "smooth_train_loss": smooth_train_loss,
                "total_training_time": total_training_time,
            },
        },
        rank=ddp_rank,
    )
```
This needs to be moved above the evaluate_model() if block:
nanochat/scripts/base_train.py
Lines 301 to 315 in d58fcd9
```python
# once in a while: estimate the CORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
results = {}
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
    model.eval()
    with autocast_ctx:
        results = evaluate_model(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
    print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
    wandb_run.log({
        "step": step,
        "total_training_flops": flops_so_far,
        "core_metric": results["core_metric"],
        "centered_results": results["centered_results"],
    })
    model.train()
```