Commit 5dc7cc1 (parent 0bb19df): Recitation 11

10 files changed: +597 -0 lines

recitation-11/Untitled.ipynb: +186 lines (large diff not rendered)
recitation-11/data/processed/test.pt: 7.55 MB (binary file)
Other binary files in this commit (paths not rendered in this excerpt): 7.48 MB, 9.77 KB, 58.6 KB, 2.41 MB

recitation-11/rbm_demo_utils.py: +103 lines (new file)

@@ -0,0 +1,103 @@
import os
import numpy as np
import torch
import torch.utils
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from inferno.trainers.basic import Trainer
import matplotlib.pyplot as plt
import rbm_models


class MNIST(torch.utils.data.Dataset):
    """MNIST wrapper: flattens each image, rescales and discretizes it via
    rbm_models, and optionally appends a rescaled one-hot label (10 extra
    entries per example)."""

    def __init__(self, max_len=-1, include_label=True):
        super().__init__()
        self.mnist = datasets.MNIST(
            root='./data', train=True,
            download=True, transform=transforms.ToTensor())
        self.max_len = max_len
        self.include_label = include_label

    def __len__(self):
        if self.max_len < 0:
            return len(self.mnist)
        else:
            return self.max_len

    def __getitem__(self, idx):
        (img, label) = self.mnist[idx]
        img = img.view(-1)
        img = rbm_models.discretize(rbm_models.rescale(img))
        if not self.include_label:
            return (img, label)
        # Append a one-hot encoding of the label so the model sees pixels
        # and label as a single joint vector.
        label_onehot = img.new(10).fill_(0)
        label_onehot[label] = 1
        label_onehot = rbm_models.rescale(label_onehot)
        result = torch.cat([img, label_onehot], dim=0)
        return (result, label)


class IdentityLoss(torch.nn.Module):
    """Pass-through criterion for models whose forward pass already
    returns a loss value; the target argument is ignored."""

    def forward(self, x, _):
        return x


class LossPrinter(torch.nn.Module):
    """Wraps a criterion and prints every loss value it computes."""

    def __init__(self, criterion):
        super().__init__()
        self.criterion = criterion

    def forward(self, *args, **kwargs):
        loss = self.criterion(*args, **kwargs)
        print("Loss: %f" % loss)
        return loss


def train(net, dataset, criterion, num_epochs,
          batch_size, learn_rate, dir_name):
    # `net` is a one-element list holding the model, so the trained (or
    # checkpoint-loaded) weights are visible to the caller afterwards.
    dir_name = os.path.join('net/', dir_name)
    trainer = Trainer(net[0])

    if os.path.exists(os.path.join(dir_name, 'model.pytorch')):
        # A checkpoint already exists: load its weights instead of retraining.
        net_temp = trainer.load_model(dir_name).model
        net[0].load_state_dict(net_temp.state_dict())
        print("Loaded checkpoint directly")
    else:
        if not os.path.exists(dir_name):
            os.makedirs(dir_name)
        data_loader = torch.utils.data.DataLoader(
            dataset, shuffle=True, batch_size=batch_size)
        net[0].train()

        trainer \
            .build_criterion(LossPrinter(criterion)) \
            .bind_loader('train', data_loader) \
            .build_optimizer('Adam', lr=learn_rate) \
            .set_max_num_epochs(num_epochs)

        if torch.cuda.is_available():
            trainer.cuda()

        trainer.fit()
        trainer.save_model(dir_name)
    net[0].cpu()
    net[0].eval()


def display_image(arr):
    # `arr` is a flattened image with the 10 one-hot label entries appended.
    width = int(np.sqrt(arr.size()[0]))
    label_onehot = arr[-10:]
    arr = (arr[:-10] + 1) / 2  # drop the label entries, map back to [0, 1]
    arr = arr.cpu().view(width, -1).numpy()
    plt.figure()
    plt.imshow(1.0 - arr, cmap='gray')
    _, pos = torch.max(label_onehot, 0)
    print(pos[0])


def display_reconstruction(net, dataset):
    # Show a random example next to its reconstruction after one
    # encode/decode pass through the network.
    (image, _) = dataset[np.random.randint(len(dataset))]
    display_image(image)
    image = torch.autograd.Variable(image).unsqueeze(dim=0)
    reconst = net.decode(net.encode(image)).data[0]
    display_image(reconst)
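
For reference, a minimal usage sketch for these helpers (not part of the commit): it assumes rbm_models exposes an RBM-style module with encode/decode methods, as display_reconstruction requires; the RBM class name, its constructor arguments, and the hyperparameter values below are hypothetical placeholders, not taken from this diff.

import rbm_demo_utils
import rbm_models

# Flattened MNIST digits with the one-hot label appended
# (28*28 pixels + 10 label entries = 794 values per example).
dataset = rbm_demo_utils.MNIST(max_len=10000, include_label=True)

# Hypothetical constructor; rbm_models is not shown in this diff excerpt.
rbm = rbm_models.RBM(num_visible=794, num_hidden=128)

# train() expects the model wrapped in a one-element list so the trained
# (or checkpoint-loaded) weights end up back in the caller's reference.
# IdentityLoss assumes the model's forward pass returns its own loss.
net = [rbm]
rbm_demo_utils.train(net, dataset,
                     criterion=rbm_demo_utils.IdentityLoss(),
                     num_epochs=5, batch_size=64,
                     learn_rate=1e-3, dir_name='rbm_demo')

# Show a random training digit and its encode/decode reconstruction.
rbm_demo_utils.display_reconstruction(net[0], dataset)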
