diff --git a/pae/__init__.py b/pae/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pae/util/__init__.py b/pae/util/__init__.py index 8bb3985..d504900 100644 --- a/pae/util/__init__.py +++ b/pae/util/__init__.py @@ -1 +1,2 @@ +from .setup import setup, clear from .metric import Metric, BNLLLoss \ No newline at end of file diff --git a/pae/util/setup.py b/pae/util/setup.py new file mode 100644 index 0000000..d448f0f --- /dev/null +++ b/pae/util/setup.py @@ -0,0 +1,136 @@ +import gc +import glob +import logging +import os +import random +from pathlib import Path +from functools import partial +from datetime import datetime + +import numpy +import torch +import wandb + + +def allow_print_to_master(is_master): + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + + if force or is_master: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def check_need_init(): + if os.environ.get('INITIALIZED', None): + return False + else: + return True + + +def init_distributed_mode(args): + os.environ['INITIALIZED'] = 'TRUE' + os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' + os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda + print(f'{datetime.now().strftime("[%Y/%m/%d %H:%M]")} ', end='') + + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.distributed = True + args.rank = int(os.environ['RANK']) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + args.dist_backend = 'nccl' + args.dist_url = 'env://' + + print(f'| distributed init (rank {args.rank}): {args.dist_url}') + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + else: + print('| Not using distributed mode') + args.distributed = False + args.gpu = 0 + + args.is_rank_zero = args.gpu == 0 + allow_print_to_master(args.is_rank_zero) + torch.cuda.set_device(args.gpu) 
+ args.device = torch.device(f'cuda:{args.gpu}') + + +def make_logger(log_file_path): + logger = logging.getLogger() + logger.setLevel(logging.INFO) + formatter = logging.Formatter("%(asctime)s | %(message)s", "[%Y/%m/%d %H:%M]") + + console = logging.StreamHandler() + console.setLevel(logging.INFO) + console.setFormatter(formatter) + logger.addHandler(console) + + file_handler = logging.FileHandler(filename=log_file_path) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger + + +def log(msg, metric=False, logger=None): + if logger: + if metric: + wandb.log(msg) + else: + logger.info(msg) + + +def init_logger(args): + args.exp_name = '_'.join(str(getattr(args, target)) for target in args.exp_target) + args.version_id = len(list(glob.glob(os.path.join(args.output_dir, f'{args.exp_name}_v*')))) + args.exp_name = f'{args.exp_name}_v{args.version_id}' + + args.start_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + args.log_dir = os.path.join(args.output_dir, args.exp_name) + args.text_log_path = os.path.join(args.log_dir, 'log.txt') + args.best_weight_path = os.path.join(args.log_dir, 'best_weight.pth') + + if args.is_rank_zero: + Path(args.log_dir).mkdir(parents=True, exist_ok=True) + args.logger = make_logger(args.text_log_path) + if args.use_wandb: + wandb.init(project=args.project_name, name=args.exp_name, entity=args.who, config=args, reinit=True) + else: + args.logger = None + + args.log = partial(log, logger=args.logger) + + +def clear(args): + # 1. clear gpu memory + torch.cuda.empty_cache() + # 2. clear cpu memory + gc.collect() + # 3. 
close logger + if args.is_rank_zero: + handlers = args.logger.handlers[:] + for handler in handlers: + args.logger.removeHandler(handler) + handler.close() + if args.use_wandb: + wandb.finish(quiet=True) + + +def setup(args): + if check_need_init(): + init_distributed_mode(args) + init_logger(args) + + if args.seed is not None: + numpy.random.seed(args.seed) + random.seed(args.seed) + torch.manual_seed(args.seed) + + torch.backends.cudnn.benchmark = True \ No newline at end of file diff --git a/sample.py b/sample.py index ba980f3..3f2ca1e 100644 --- a/sample.py +++ b/sample.py @@ -3,17 +3,9 @@ from pathlib import Path from torchvision.utils import save_image -from tqdm import tqdm - import torch -import torchvision.transforms as TVT -from torch.optim import AdamW -from torch.utils.data import DataLoader -from torchvision.datasets import MNIST -from pae.dataset import _dynamically_binarize from pae.model import NADE -from pae.util import Metric, BNLLLoss def get_arg_parser(): diff --git a/train.py b/train.py index 0094f1f..aedaf55 100644 --- a/train.py +++ b/train.py @@ -1,28 +1,31 @@ import os from argparse import ArgumentParser -from pathlib import Path - -from torchvision.utils import save_image -from tqdm import tqdm import torch import torchvision.transforms as TVT +import wandb from torch.optim import AdamW from torch.utils.data import DataLoader from torchvision.datasets import MNIST +from torchvision.utils import save_image from pae.dataset import _dynamically_binarize from pae.model import NADE -from pae.util import Metric, BNLLLoss +from pae.util import Metric, BNLLLoss, setup, clear def get_arg_parser(): # 1. 
setting parser = ArgumentParser(description='pytorch-auto-encoder') parser.add_argument('--data-dir', type=str, default=os.path.join('data', 'mnist'), help='root path of dataset') - parser.add_argument('--log-dir', type=str, default='log', help='root log dir') + parser.add_argument('--output-dir', type=str, default='log', help='root log dir') + parser.add_argument('--who', type=str, default="hankyul2", help="entity name used for logger") parser.add_argument('--project-name', type=str, default="pytorch-ae", help="project name used for logger") + parser.add_argument('--exp-target', type=str, nargs='+', default=['model_name'], help="arguments for experiment name") parser.add_argument('--cuda', type=str, default='0,', help="cuda devices") + parser.add_argument('--print-freq', type=int, default=20, help='print log frequency') + parser.add_argument('--use-wandb', action='store_true', help="use wandb to log metric") + parser.add_argument('--seed', type=int, default=42, help='fix randomness for better reproducibility') # 2.
model parser.add_argument('-m', '--model-name', type=str, default='NADE', help='the name of model') @@ -37,28 +40,31 @@ def get_arg_parser(): @torch.no_grad() -def validate(dataloader, f, critic, device, epoch, log_dir): +def validate(dataloader, f, critic, args, epoch): loss_m = Metric(header='Loss:') - prog_bar = tqdm(dataloader, leave=True) - for x, y in prog_bar: - x = x.to(device) + total_iter = len(dataloader) + for batch_idx, (x, y) in enumerate(dataloader): + x = x.to(args.device) y_hat, x_hat = f(x) loss = critic(y_hat, x) loss_m.update(loss, len(x)) - prog_bar.set_description(f"Val {loss_m}") - save_image(x_hat[:16], os.path.join(log_dir, f"val_{epoch}.jpg")) + if batch_idx and args.print_freq and batch_idx % args.print_freq == 0: + num_digits = len(str(total_iter)) + args.log(f"VALID({epoch:03}): [{batch_idx:>{num_digits}}/{total_iter}] {loss_m}") + + save_image(x_hat[:16], os.path.join(args.log_dir, f"val_{epoch}.jpg")) return loss_m.compute() -def train(dataloader, f, critic, optim, device): +def train(dataloader, f, critic, optim, args, epoch): loss_m = Metric(header='Loss:') - prog_bar = tqdm(dataloader, leave=True) - for x, y in prog_bar: - x = x.to(device) + total_iter = len(dataloader) + for batch_idx, (x, y) in enumerate(dataloader): + x = x.to(args.device) y_hat, x_hat = f(x) loss = critic(y_hat, x) @@ -67,22 +73,17 @@ def train(dataloader, f, critic, optim, device): optim.zero_grad() loss_m.update(loss, len(x)) - prog_bar.set_description(f"Train {loss_m}") - - return loss_m.compute() + if batch_idx and args.print_freq and batch_idx % args.print_freq == 0: + num_digits = len(str(total_iter)) + args.log(f"TRAIN({epoch:03}): [{batch_idx:>{num_digits}}/{total_iter}] {loss_m}") -def setup(args): - os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' - os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda - args.device = 'cuda:0' + return loss_m.compute() - log_root_path = os.path.join(args.log_dir, args.model_name) - Path(log_root_path).mkdir(exist_ok=True, 
parents=True) - args.run_id = f"v{len(os.listdir(log_root_path))}" - args.log_dir = os.path.join(log_root_path, args.run_id) - Path(args.log_dir).mkdir(exist_ok=True, parents=True) +def sample(f, args, epoch): + sampled_img = f.sample(16, args.device).reshape(16, 1, 28, 28) + save_image(sampled_img, os.path.join(args.log_dir, f'sample_{epoch}.jpg')) def run(args): @@ -103,17 +104,26 @@ def run(args): critic = BNLLLoss() best_loss = 1000.0 + num_digit = len(str(args.epoch)) for epoch in range(args.epoch): - train_loss = train(train_dataloader, f, critic, optim, args.device) - val_loss = validate(val_dataloader, f, critic, args.device, epoch, args.log_dir) - sampled_img = f.sample(16, args.device).reshape(16, 1, 28, 28) - save_image(sampled_img, os.path.join(args.log_dir, f'sample_{epoch}.jpg')) + train_loss = train(train_dataloader, f, critic, optim, args, epoch) + val_loss = validate(val_dataloader, f, critic, args, epoch) + sample(f, args, epoch) + + args.log(f"EPOCH({epoch:>{num_digit}}/{args.epoch}): Train Loss: {train_loss:.04f} Val Loss: {val_loss:.04f}") + if args.use_wandb: + args.log({ + 'train_loss':train_loss, 'val_loss':val_loss, + 'val_img': wandb.Image(os.path.join(args.log_dir, f'val_{epoch}.jpg')), + 'sample_img': wandb.Image(os.path.join(args.log_dir, f'sample_{epoch}.jpg')) + }, metric=True) if best_loss > val_loss: best_loss = val_loss - state_dict = {k: v.cpu() for k, v in f.state_dict().items()}, + state_dict = {k: v.cpu() for k, v in f.state_dict().items()} torch.save(state_dict, os.path.join(args.log_dir, f'{args.model_name}.pth')) - print(f"saved model (val loss: {best_loss:0.4f}) in to {args.log_dir}") + args.log(f"Saved model (val loss: {best_loss:0.4f}) into {args.log_dir}") + clear(args) if __name__ == '__main__':