-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
125 lines (98 loc) · 3.48 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from PIL import Image
from skimage.color import rgb2lab, lab2rgb
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid
class make_data(Dataset):
    """Dataset that serves images as scaled Lab channels for colorization.

    Each item is read from a path in ``data_path``, converted to RGB,
    resized to ``img_size`` x ``img_size`` (plus a random horizontal flip
    for the ``'train'`` split), converted to Lab with ``rgb2lab`` and split
    into an ``L`` tensor and an ``ab`` tensor.

    Args:
        data_path: Indexable, sized collection of image file paths.
        split: ``'train'`` or ``'val'``; selects the transform pipeline.
        img_size: Side length in pixels of the square resized image.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, data_path, split, img_size=256):
        # Fail fast: with an unknown split ``self.transforms`` would never be
        # assigned and the error would only surface later, as an
        # AttributeError inside __getitem__.
        if split not in ('train', 'val'):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((img_size, img_size), Image.BICUBIC),
                transforms.RandomHorizontalFlip(),
            ])
        else:
            self.transforms = transforms.Resize(
                (img_size, img_size), Image.BICUBIC
            )
        self.data_path = data_path
        self.split = split
        self.img_size = img_size

    def __getitem__(self, idx):
        """Return ``{'L': ..., 'ab': ...}`` for the image at index ``idx``."""
        # The context manager releases the file handle once the pixels have
        # been decoded by ``convert``.
        with Image.open(self.data_path[idx]) as raw:
            img = raw.convert("RGB")
        img = np.array(self.transforms(img))
        img_to_lab = rgb2lab(img).astype("float32")
        # After ToTensor the channel axis comes first, so index 0 is L and
        # indices 1-2 are a and b.
        img_to_lab = transforms.ToTensor()(img_to_lab)
        # List indexing keeps the channel axis on the single-channel L slice.
        L = img_to_lab[[0], ...] / 50. - 1.
        ab = img_to_lab[[1, 2], ...] / 110.
        return {'L': L, 'ab': ab}

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.data_path)
class AverageMeter:
    """Accumulate a weighted running sum, count and average of a value."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out the accumulated count, sum and average."""
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        """Fold ``val`` into the statistics, weighted by ``count``."""
        self.sum += count * val
        self.count += count
        # Recomputed from the totals on every update rather than kept
        # incrementally.
        self.avg = self.sum / self.count
def create_loss_meters():
    """Build a dict with one fresh ``AverageMeter`` per tracked loss name.

    Returns:
        A dict keyed by loss name, in this order: discriminator fake/real/
        total, then generator GAN/L1/total. Each value is its own meter.
    """
    meter_names = (
        'loss_D_fake',
        'loss_D_real',
        'loss_D',
        'loss_G_GAN',
        'loss_G_L1',
        'loss_G',
    )
    return {meter_name: AverageMeter() for meter_name in meter_names}
def update_losses(model, loss_meter_dict, count):
    """Feed the model's current loss values into their matching meters.

    For every key in ``loss_meter_dict`` the attribute of the same name is
    read from ``model``, reduced to a plain number with ``.item()`` and
    passed to that meter's ``update`` with the given ``count`` as weight.

    Args:
        model: Object exposing one attribute per loss name.
        loss_meter_dict: Mapping of loss name to meter.
        count: Weight for this update, forwarded as ``count=``.
    """
    for name in loss_meter_dict:
        current = getattr(model, name).item()
        loss_meter_dict[name].update(current, count=count)
def lab_to_rgb(L, ab):
    """Convert a batch of scaled L and ab tensors into RGB images.

    ``L`` is rescaled with ``(L + 1) * 50`` and ``ab`` with ``ab * 110``
    before the two are concatenated along dim 1, moved to channels-last,
    and converted sample by sample with ``lab2rgb``.

    Args:
        L: Batched tensor holding the scaled L channel on dim 1.
        ab: Batched tensor holding the scaled a and b channels on dim 1.

    Returns:
        A NumPy array stacking the converted images along axis 0.
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 110.
    lab_batch = torch.cat([lightness, chroma], dim=1)
    # lab2rgb is applied per sample on channels-last NumPy arrays.
    lab_batch = lab_batch.permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(sample) for sample in lab_batch], axis=0)
def visualize(model, data, save=True):
    """Plot grayscale inputs, generated colorizations and ground truth.

    Runs the generator in eval mode without gradients on ``data``, then
    draws a three-row grid: the L channel (top), the generated images
    (middle) and the real images (bottom). At most five samples are shown.

    Args:
        model: Object exposing ``G_net``, ``setup_input``, ``forward`` and,
            after the forward pass, ``fake_color``, ``ab`` and ``L``.
        data: Batch passed to ``model.setup_input``.
        save: When true, write the figure to
            ``colorization_<timestamp>.png`` in the working directory.
    """
    n_cols = 5

    model.G_net.eval()
    try:
        with torch.no_grad():
            model.setup_input(data)
            model.forward()
    finally:
        # Put the generator back in train mode even if the forward pass fails.
        model.G_net.train()

    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)

    # Batches smaller than the grid width would otherwise raise IndexError.
    n_shown = min(n_cols, len(fake_imgs))

    fig = plt.figure(figsize=(15, 8))
    for i in range(n_shown):
        ax = plt.subplot(3, n_cols, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")
        ax = plt.subplot(3, n_cols, i + 1 + n_cols)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = plt.subplot(3, n_cols, i + 1 + 2 * n_cols)
        ax.imshow(real_imgs[i])
        ax.axis("off")

    # Save before show(): show() can block until the window is closed, and
    # the file should be written regardless of what happens to the window.
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
    plt.show()
def log_results(loss_meter_dict):
    """Print each meter's running average as ``name: value``, 5 decimals.

    Args:
        loss_meter_dict: Mapping of loss name to an object with an ``avg``
            attribute; entries are printed in the mapping's iteration order.
    """
    for name in loss_meter_dict:
        average = loss_meter_dict[name].avg
        print(f"{name}: {average:.5f}")