Commit

feat: add wandb logger
- add wandb logger
- save image to wandb
- remove tqdm
hankyul2 committed Dec 24, 2022
1 parent 842a1f6 commit 73dade4
Showing 5 changed files with 180 additions and 41 deletions.
Empty file added pae/__init__.py
1 change: 1 addition & 0 deletions pae/util/__init__.py
@@ -1 +1,2 @@
from .setup import setup, clear
from .metric import Metric, BNLLLoss
136 changes: 136 additions & 0 deletions pae/util/setup.py
@@ -0,0 +1,136 @@
import gc
import glob
import logging
import os
import random
from pathlib import Path
from functools import partial
from datetime import datetime

import numpy
import torch
import wandb


def allow_print_to_master(is_master):
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)

        if force or is_master:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def check_need_init():
    if os.environ.get('INITIALIZED', None):
        return False
    else:
        return True


def init_distributed_mode(args):
    os.environ['INITIALIZED'] = 'TRUE'
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    print(f'{datetime.now().strftime("[%Y/%m/%d %H:%M]")} ', end='')

    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.distributed = True
        args.rank = int(os.environ['RANK'])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
        args.dist_backend = 'nccl'
        args.dist_url = 'env://'

        print(f'| distributed init (rank {args.rank}): {args.dist_url}')
        torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                             world_size=args.world_size, rank=args.rank)
    else:
        print('| Not using distributed mode')
        args.distributed = False
        args.gpu = 0

    args.is_rank_zero = args.gpu == 0
    allow_print_to_master(args.is_rank_zero)
    torch.cuda.set_device(args.gpu)
    args.device = torch.device(f'cuda:{args.gpu}')


def make_logger(log_file_path):
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s | %(message)s", "[%Y/%m/%d %H:%M]")

    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(formatter)
    logger.addHandler(console)

    file_handler = logging.FileHandler(filename=log_file_path)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger


def log(msg, metric=False, logger=None):
    if logger:
        if metric:
            wandb.log(msg)
        else:
            logger.info(msg)


def init_logger(args):
    args.exp_name = '_'.join(str(getattr(args, target)) for target in args.exp_target)
    args.version_id = len(list(glob.glob(os.path.join(args.output_dir, f'{args.exp_name}_v*'))))
    args.exp_name = f'{args.exp_name}_v{args.version_id}'

    args.start_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    args.log_dir = os.path.join(args.output_dir, args.exp_name)
    args.text_log_path = os.path.join(args.log_dir, 'log.txt')
    args.best_weight_path = os.path.join(args.log_dir, 'best_weight.pth')

    if args.is_rank_zero:
        Path(args.log_dir).mkdir(parents=True, exist_ok=True)
        args.logger = make_logger(args.text_log_path)
        if args.use_wandb:
            wandb.init(project=args.project_name, name=args.exp_name, entity=args.who, config=args, reinit=True)
    else:
        args.logger = None

    args.log = partial(log, logger=args.logger)


def clear(args):
    # 1. clear gpu memory
    torch.cuda.empty_cache()
    # 2. clear cpu memory
    gc.collect()
    # 3. close logger
    if args.is_rank_zero:
        handlers = args.logger.handlers[:]
        for handler in handlers:
            args.logger.removeHandler(handler)
            handler.close()
        if args.use_wandb:
            wandb.finish(quiet=True)


def setup(args):
    if check_need_init():
        init_distributed_mode(args)
        init_logger(args)

    if args.seed is not None:
        numpy.random.seed(args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)

    torch.backends.cudnn.benchmark = True
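
A minimal usage sketch of the new helpers, assuming an argparse.Namespace that already carries the fields setup() reads (cuda, seed, output_dir, exp_target, use_wandb, plus project_name and who when wandb is enabled); the values below are illustrative, not part of this commit:

from argparse import Namespace

from pae.util import setup, clear

# Hypothetical namespace; in train.py these fields come from get_arg_parser().
args = Namespace(cuda='0,', seed=42, output_dir='log', exp_target=['model_name'],
                 model_name='NADE', use_wandb=False, project_name='pytorch-ae', who='hankyul2')

setup(args)                      # device / distributed init, logger, optional wandb.init, seeds
args.log('plain text message')   # strings go to the console and log.txt via logging
# args.log({'val_loss': 0.1}, metric=True)   # dicts go to wandb.log; only meaningful with --use-wandb
clear(args)                      # empty CUDA cache, gc.collect, close log handlers, wandb.finish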
8 changes: 0 additions & 8 deletions sample.py
@@ -3,17 +3,9 @@
from pathlib import Path

from torchvision.utils import save_image
from tqdm import tqdm

import torch
import torchvision.transforms as TVT
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from pae.dataset import _dynamically_binarize
from pae.model import NADE
from pae.util import Metric, BNLLLoss


def get_arg_parser():
76 changes: 43 additions & 33 deletions train.py
@@ -1,28 +1,31 @@
import os
from argparse import ArgumentParser
from pathlib import Path

from torchvision.utils import save_image
from tqdm import tqdm

import torch
import torchvision.transforms as TVT
import wandb
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.utils import save_image

from pae.dataset import _dynamically_binarize
from pae.model import NADE
from pae.util import Metric, BNLLLoss
from pae.util import Metric, BNLLLoss, setup, clear


def get_arg_parser():
    # 1. setting
    parser = ArgumentParser(description='pytorch-auto-encoder')
    parser.add_argument('--data-dir', type=str, default=os.path.join('data', 'mnist'), help='root path of dataset')
    parser.add_argument('--log-dir', type=str, default='log', help='root log dir')
    parser.add_argument('--output-dir', type=str, default='log', help='root log dir')
    parser.add_argument('--who', type=str, default="hankyul2", help="entity name used for logger")
    parser.add_argument('--project-name', type=str, default="pytorch-ae", help="project name used for logger")
    parser.add_argument('--exp-target', type=str, default=['model_name'], help="arguments for experiment name")
    parser.add_argument('--cuda', type=str, default='0,', help="cuda devices")
    parser.add_argument('--print-freq', type=int, default=20, help='print log frequency')
    parser.add_argument('--use-wandb', action='store_true', help="use wandb to log metric")
    parser.add_argument('--seed', type=int, default=42, help='fix randomness for better reproducibility')

    # 2. model
    parser.add_argument('-m', '--model-name', type=str, default='NADE', help='the name of model')
@@ -37,28 +40,31 @@ def get_arg_parser():


@torch.no_grad()
def validate(dataloader, f, critic, device, epoch, log_dir):
def validate(dataloader, f, critic, args, epoch):
    loss_m = Metric(header='Loss:')
    prog_bar = tqdm(dataloader, leave=True)
    for x, y in prog_bar:
        x = x.to(device)
    total_iter = len(dataloader)
    for batch_idx, (x, y) in enumerate(dataloader):
        x = x.to(args.device)

        y_hat, x_hat = f(x)
        loss = critic(y_hat, x)

        loss_m.update(loss, len(x))
        prog_bar.set_description(f"Val {loss_m}")

    save_image(x_hat[:16], os.path.join(log_dir, f"val_{epoch}.jpg"))
        if batch_idx and args.print_freq and batch_idx % args.print_freq == 0:
            num_digits = len(str(total_iter))
            args.log(f"VALID({epoch:03}): [{batch_idx:>{num_digits}}/{total_iter}] {loss_m}")

    save_image(x_hat[:16], os.path.join(args.log_dir, f"val_{epoch}.jpg"))

    return loss_m.compute()


def train(dataloader, f, critic, optim, device):
def train(dataloader, f, critic, optim, args, epoch):
    loss_m = Metric(header='Loss:')
    prog_bar = tqdm(dataloader, leave=True)
    for x, y in prog_bar:
        x = x.to(device)
    total_iter = len(dataloader)
    for batch_idx, (x, y) in enumerate(dataloader):
        x = x.to(args.device)

        y_hat, x_hat = f(x)
        loss = critic(y_hat, x)
@@ -67,22 +73,17 @@ def train(dataloader, f, critic, optim, device):
        optim.zero_grad()

        loss_m.update(loss, len(x))
        prog_bar.set_description(f"Train {loss_m}")

    return loss_m.compute()

        if batch_idx and args.print_freq and batch_idx % args.print_freq == 0:
            num_digits = len(str(total_iter))
            args.log(f"TRAIN({epoch:03}): [{batch_idx:>{num_digits}}/{total_iter}] {loss_m}")

def setup(args):
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    args.device = 'cuda:0'
    return loss_m.compute()

    log_root_path = os.path.join(args.log_dir, args.model_name)
    Path(log_root_path).mkdir(exist_ok=True, parents=True)

    args.run_id = f"v{len(os.listdir(log_root_path))}"
    args.log_dir = os.path.join(log_root_path, args.run_id)
    Path(args.log_dir).mkdir(exist_ok=True, parents=True)
def sample(f, args, epoch):
    sampled_img = f.sample(16, args.device).reshape(16, 1, 28, 28)
    save_image(sampled_img, os.path.join(args.log_dir, f'sample_{epoch}.jpg'))


def run(args):
@@ -103,17 +104,26 @@ def run(args):
    critic = BNLLLoss()

    best_loss = 1000.0
    num_digit = len(str(args.epoch))
    for epoch in range(args.epoch):
        train_loss = train(train_dataloader, f, critic, optim, args.device)
        val_loss = validate(val_dataloader, f, critic, args.device, epoch, args.log_dir)
        sampled_img = f.sample(16, args.device).reshape(16, 1, 28, 28)
        save_image(sampled_img, os.path.join(args.log_dir, f'sample_{epoch}.jpg'))
        train_loss = train(train_dataloader, f, critic, optim, args, epoch)
        val_loss = validate(val_dataloader, f, critic, args, epoch)
        sample(f, args, epoch)

        args.log(f"EPOCH({epoch:>{num_digit}}/{args.epoch}): Train Loss: {train_loss:.04f} Val Loss: {val_loss:.04f}")
        if args.use_wandb:
            args.log({
                'train_loss': train_loss, 'val_loss': val_loss,
                'val_img': wandb.Image(os.path.join(args.log_dir, f'val_{epoch}.jpg')),
                'sample_img': wandb.Image(os.path.join(args.log_dir, f'sample_{epoch}.jpg'))
            }, metric=True)

        if best_loss > val_loss:
            best_loss = val_loss
            state_dict = {k: v.cpu() for k, v in f.state_dict().items()}
            torch.save(state_dict, os.path.join(args.log_dir, f'{args.model_name}.pth'))
        print(f"saved model (val loss: {best_loss:0.4f}) in to {args.log_dir}")
            args.log(f"Saved model (val loss: {best_loss:0.4f}) to {args.log_dir}")
    clear(args)


if __name__ == '__main__':
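
The reworked train()/validate() loops touch pae.util.Metric only through Metric(header=...), update(loss, batch_size), compute(), and its string form inside the log messages. The actual class lives in pae/util/metric.py and is not part of this diff; a minimal stand-in consistent with that usage (an assumption, not the project's implementation) looks like:

import torch

class Metric:
    # Hypothetical stand-in: a running weighted average with a printable header.
    def __init__(self, header=''):
        self.header = header
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        # value is the batch loss (possibly a 0-dim tensor), n the batch size
        if torch.is_tensor(value):
            value = value.detach().item()
        self.total += value * n
        self.count += n

    def compute(self):
        return self.total / max(self.count, 1)

    def __str__(self):
        return f'{self.header} {self.compute():.4f}'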
