diff --git a/train.py b/train.py
index 9efece250581..390ff86a2593 100644
--- a/train.py
+++ b/train.py
@@ -65,6 +65,18 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))



+def print_dict(d):
+    """Log every key/value pair of a dict at INFO level, one line per entry."""
+    for key, value in d.items():
+        LOGGER.info(f"{key}: {value}")
+
+
+def print_namespace(ns):
+    """Log every attribute of a namespace (e.g. argparse.Namespace) at INFO level."""
+    # vars(ns).items() yields each attribute and its value in one lookup,
+    # avoiding a redundant getattr() per attribute.
+    for attr, value in vars(ns).items():
+        LOGGER.info(f"{attr}: {value}")
+
+
 def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
     save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
@@ -384,6 +396,10 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
                     if best_fitness == fi:
                         torch.save(ckpt, best)
                     if opt.save_period > 0 and epoch % opt.save_period == 0:
+                        LOGGER.info(f'Saving checkpoint at {epoch} epochs')
+                        LOGGER.info('Contents of opt:')
+                        print_namespace(opt)
+                        print_dict(ckpt)
                         torch.save(ckpt, w / f'epoch{epoch}.pt')
                     del ckpt
                 callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)