|
| 1 | +import os |
| 2 | +import numpy as np |
| 3 | +import torch |
| 4 | +import torch.utils |
| 5 | +import torchvision.datasets as datasets |
| 6 | +import torchvision.transforms as transforms |
| 7 | +from inferno.trainers.basic import Trainer |
| 8 | +import matplotlib.pyplot as plt |
| 9 | +import rbm_models |
| 10 | + |
| 11 | + |
class MNIST(torch.utils.data.Dataset):
    """Flattened MNIST training digits, optionally with a label vector appended.

    Each image is flattened to 1-D and passed through ``rbm_models.rescale``
    and ``rbm_models.discretize``. When ``include_label`` is true, a
    10-entry one-hot label vector (also passed through
    ``rbm_models.rescale``) is concatenated after the pixels.

    Args:
        max_len: Upper bound on the number of samples exposed. A negative
            value (the default) exposes the whole training set. Values
            larger than the training set are clamped to its size.
        include_label: Whether to append the rescaled one-hot label to
            each sample vector.
    """

    def __init__(self, max_len=-1, include_label=True):
        super().__init__()
        # Downloads the training split into ./data on first use.
        self.mnist = datasets.MNIST(
            root='./data', train=True,
            download=True, transform=transforms.ToTensor())
        self.max_len = max_len
        self.include_label = include_label

    def __len__(self):
        """Return the number of samples exposed by this dataset."""
        full_len = len(self.mnist)
        if self.max_len < 0:
            return full_len
        # Never report more samples than actually exist; otherwise a
        # DataLoader would request out-of-range indices.
        return min(self.max_len, full_len)

    def __getitem__(self, idx):
        """Return ``(vector, label)`` for sample ``idx``.

        Negative indices count from the end of the (possibly truncated)
        dataset.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        length = len(self)
        if idx < 0:
            idx = idx + length
        # Enforce the truncated length so that plain iteration terminates
        # at max_len instead of running over the full training set.
        if not 0 <= idx < length:
            raise IndexError("MNIST index out of range")

        (img, label) = self.mnist[idx]
        img = img.view(-1)
        # NOTE(review): display_image maps values with (x + 1) / 2, so
        # rescale presumably targets [-1, 1] -- confirm in rbm_models.
        img = rbm_models.discretize(rbm_models.rescale(img))
        if not self.include_label:
            return (img, label)

        # One-hot encode the digit (10 classes) with the same tensor type
        # as the image so the two can be concatenated.
        label_onehot = img.new(10).fill_(0)
        label_onehot[label] = 1
        label_onehot = rbm_models.rescale(label_onehot)
        result = torch.cat([img, label_onehot], dim=0)
        return (result, label)
| 38 | + |
| 39 | + |
class IdentityLoss(torch.nn.Module):
    """Criterion that passes the model output through as the loss.

    The second argument (the target supplied by the training loop) is
    accepted for signature compatibility and ignored.
    """

    def forward(self, x, _):
        """Return ``x`` unchanged; the target argument is unused."""
        return x
| 43 | + |
| 44 | + |
class LossPrinter(torch.nn.Module):
    """Wrap a criterion so every computed loss is echoed to stdout."""

    def __init__(self, criterion):
        """Store the wrapped ``criterion`` callable."""
        super().__init__()
        self.criterion = criterion

    def forward(self, *args, **kwargs):
        """Evaluate the wrapped criterion, print its value, and return it.

        All positional and keyword arguments are forwarded untouched.
        """
        value = self.criterion(*args, **kwargs)
        message = "Loss: %f" % value
        print(message)
        return value
| 54 | + |
| 55 | + |
def train(net, dataset, criterion, num_epochs,
          batch_size, learn_rate, dir_name):
    """Train ``net[0]`` on ``dataset``, or restore it from a saved checkpoint.

    If ``net/<dir_name>/model.pytorch`` exists, its weights are loaded into
    ``net[0]`` and no training happens. Otherwise the model is trained with
    Adam and the result is saved to ``net/<dir_name>``. Either way the model
    is left on the CPU in evaluation mode.

    Args:
        net: Indexable container whose first element is the model; the
            model is updated in place.
        dataset: Dataset handed to a shuffling ``DataLoader``.
        criterion: Loss callable; wrapped in ``LossPrinter`` so each loss
            value is printed.
        num_epochs: Maximum number of training epochs.
        batch_size: Mini-batch size for the data loader.
        learn_rate: Learning rate passed to the Adam optimizer.
        dir_name: Checkpoint sub-directory name, created under ``net/``.
    """
    dir_name = os.path.join('net/', dir_name)
    trainer = Trainer(net[0])

    if os.path.exists(os.path.join(dir_name, 'model.pytorch')):
        # Copy the weights into the caller's model object instead of
        # replacing it, so references held by the caller stay valid.
        net_temp = trainer.load_model(dir_name).model
        net[0].load_state_dict(net_temp.state_dict())
        print("Loaded checkpoint directly")
    else:
        # exist_ok avoids a check-then-create race and fails immediately
        # (before training) if the path exists but is not a directory.
        os.makedirs(dir_name, exist_ok=True)
        data_loader = torch.utils.data.DataLoader(
            dataset, shuffle=True, batch_size=batch_size)
        net[0].train()

        trainer \
            .build_criterion(LossPrinter(criterion)) \
            .bind_loader('train', data_loader) \
            .build_optimizer('Adam', lr=learn_rate) \
            .set_max_num_epochs(num_epochs)

        if torch.cuda.is_available():
            trainer.cuda()

        trainer.fit()
        trainer.save_model(dir_name)

    # Hand the model back on the CPU and in inference mode on both paths.
    net[0].cpu()
    net[0].eval()
| 85 | + |
| 86 | + |
def display_image(arr):
    """Plot a flattened sample and print the digit encoded in its label part.

    Args:
        arr: 1-D tensor whose last 10 entries are the label vector and whose
            remaining entries are the pixels of a square image.
            NOTE(review): pixels are mapped with ``(x + 1) / 2``, so they
            are presumably in [-1, 1] -- confirm against rbm_models.rescale.
    """
    label_onehot = arr[-10:]
    pixels = (arr[:-10] + 1) / 2
    # Derive the side length from the pixel count only, so the trailing
    # label entries cannot skew it.
    width = int(np.sqrt(pixels.size()[0]))
    pixels = pixels.cpu().view(width, -1).numpy()
    plt.figure()
    # Invert the intensities before drawing with the gray colormap.
    plt.imshow(1.0 - pixels, cmap='gray')
    # torch.max on a 1-D tensor yields a 0-dim index on PyTorch >= 0.4,
    # which cannot be subscripted with [0]; take the argmax through numpy,
    # which returns a plain integer on every version.
    print(int(np.argmax(label_onehot.cpu().numpy())))
| 96 | + |
| 97 | + |
def display_reconstruction(net, dataset):
    """Show a random sample from ``dataset`` next to its reconstruction.

    The sample is displayed, then encoded and decoded by ``net`` as a
    batch of one, and the decoded result is displayed as well.
    """
    sample_index = np.random.randint(len(dataset))
    original, _ = dataset[sample_index]
    display_image(original)

    # Add a leading batch dimension before feeding the model.
    batch = torch.autograd.Variable(original).unsqueeze(dim=0)
    latent = net.encode(batch)
    reconstruction = net.decode(latent).data[0]
    display_image(reconstruction)
0 commit comments