<!DOCTYPE html>
<html>
<head>
<title>Unsupervised Learning - Generative Adversarial Networks [Marc Lelarge]</title>
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8"/>
<link rel="stylesheet" href="./assets/katex.min.css">
<link rel="stylesheet" type="text/css" href="./assets/slides.css">
<link rel="stylesheet" type="text/css" href="./assets/grid.css">
</head>
<body>
<textarea id="source">
class: center, middle, title-slide
count: false
# Module 10
## Unsupervised learning
## Generative Adversarial Networks
<br/><br/>
.bold[Marc Lelarge]
---
# Overview of the course:
1- .grey[Course overview: machine learning pipeline]
2- .grey[PyTorch tensors and automatic differentiation]
3- .grey[Classification with deep learning]
4- .grey[Convolutional neural networks]
5- .grey[Embedding layers and dataloaders]
6- Unsupervised learning: auto-encoders and generative adversarial networks
* .red[Deep learning architectures] - in PyTorch: recap!
---
# Generative Adversarial Networks (GAN)
## practicals: implementing conditional GAN and (simplified) InfoGAN
---
# Generative Adversarial Networks (GAN)
## Learning high-dimension generative models
The idea behind GANs is to train two networks jointly:
* a discriminator $\mathbf{D}$ to classify samples as "real" or "fake"
* a generator $\mathbf{G}$ to map a fixed distribution to samples that fool $\mathbf{D}$
.center[
<img src="images/module10/GAN_archi.png" style="width: 700px;" />
]
.credit[
Goodfellow et al. [Generative adversarial nets](https://arxiv.org/abs/1406.2661). 2014.
]
---
# GAN learning
The discriminator $\mathbf{D}$ is a classifier and $\mathbf{D}(x)$ is interpreted as the probability for $x$ to be a real sample.
The generator $\mathbf{G}$ takes as input a Gaussian random variable $z$ and produces a fake sample $\mathbf{G}(z)$.
The discriminator and the generator are learned alternately, i.e. when the parameters of $\mathbf{D}$ are learned, $\mathbf{G}$ is fixed, and vice versa.
--
When $\mathbf{G}$ is fixed, the learning of $\mathbf{D}$ is the standard learning process of a binary classifier (Sigmoid layer + BCE loss).
--
count: false
The learning of $\mathbf{G}$ is more subtle: the performance of $\mathbf{G}$ is measured through the discriminator $\mathbf{D}$, i.e. the generator .red[maximizes] the loss of the discriminator.
---
# Learning of $\mathbf{D}$
The task of $\mathbf{D}$ is to distinguish real points $x_1,\dots, x_N$ from generated points $\mathbf{G}(z_1),\dots, \mathbf{G}(z_N)$.
The last layer of $\mathbf{D}$ is a Sigmoid layer, so learning $\mathbf{D}$ amounts to minimizing the binary cross-entropy loss:
$$
\mathcal{L}(\mathbf{D},\mathbf{G}) = -\sum_{n=1}^N \log \mathbf{D}(x_n)+\log \left( 1-\mathbf{D}(\mathbf{G}(z_n))\right).
$$
--
count: false
For a fixed generator $\mathbf{G}$, the optimal discriminator is
$$
\mathbf{D}^* = \arg\min_{\mathbf{D}}\mathcal{L}(\mathbf{D},\mathbf{G}).
$$
In the infinite-sample limit, this optimum has the closed form $\mathbf{D}^*(x) = p(x)/(p(x)+p_{\mathbf{G}}(x))$, where $p$ is the density of real samples (Goodfellow et al. 2014).
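In PyTorch this is just `nn.BCELoss` with target 1 on real and 0 on fake samples; a minimal self-contained sketch (toy networks just for illustration, not the ones from the practicals):
```
import torch
import torch.nn as nn

# toy networks: 2-d data, 32-d latent
net_G = nn.Sequential(nn.Linear(32, 2))
net_D = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())
bce = nn.BCELoss()

x = torch.randn(50, 2)   # stand-in for a batch of real points
z = torch.randn(50, 32)  # Gaussian latent codes
D_real, D_fake = net_D(x), net_D(net_G(z))
# targets 1 on real / 0 on fake recover L(D, G) up to the 1/N averaging
loss_D = bce(D_real, torch.ones_like(D_real)) + bce(D_fake, torch.zeros_like(D_fake))
```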
---
# Learning of $\mathbf{G}$
The task of $\mathbf{G}$ is to fool the discriminator.
For a fixed discriminator $\mathbf{D}$, the optimal generator is
$$
\mathbf{G}^* = \color{red}{\arg\max_{\mathbf{G}}} \mathcal{L}(\mathbf{D},\mathbf{G})= \arg\max_{\mathbf{G}} -\sum_{n=1}^N \log \left( 1-\mathbf{D}(\mathbf{G}(z_n))\right).
$$
In practice, the loss for $\mathbf{G}$ is often replaced by:
$$
\mathbf{G}^* = \arg\max_{\mathbf{G}} \sum_{n=1}^N \log \left(\mathbf{D}(\mathbf{G}(z_n))\right).
$$
---
# Loss function for $\mathbf{G}$
.center[
<img src="images/module10/loss_G.png" style="width: 500px;" />
]
When the generator is weak compared to the discriminator, i.e. when $\mathbf{D}(\mathbf{G}(z)) \ll 1$, the modified loss boosts the learning of the generator thanks to the high slope of $\log$ around zero.
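A quick numerical check of this slope argument (standalone sketch): at $\mathbf{D}(\mathbf{G}(z))=0.01$ the original loss gives a gradient of magnitude about $1$, the modified loss about $100$.
```
import torch

s = torch.tensor(0.01, requires_grad=True)          # D(G(z)) for a weak generator
g_orig, = torch.autograd.grad(torch.log(1 - s), s)  # original loss: -1/(1-s)
g_mod,  = torch.autograd.grad(torch.log(s), s)      # modified loss: 1/s
print(g_orig.item(), g_mod.item())                  # ~ -1.01 vs 100.0
```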
---
# GAN for 2-d point cloud
Creating the generator and discriminator.
```
import torch
import torch.nn as nn

z_dim = 32
hidden_dim = 128
# generator: 32-d Gaussian noise -> 2-d point
net_G = nn.Sequential(nn.Linear(z_dim, hidden_dim),
                      nn.ReLU(),
                      nn.Linear(hidden_dim, 2))
# discriminator: 2-d point -> probability of being real
net_D = nn.Sequential(nn.Linear(2, hidden_dim),
                      nn.ReLU(),
                      nn.Linear(hidden_dim, 1),
                      nn.Sigmoid())
```
The point cloud will be given as a numpy array X.
.center[
<img src="images/module10/double_moon.png" style="width: 300px;" />
]
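The slides do not specify how `X` is built; one possible way (an assumption) is scikit-learn's `make_moons`:
```
import numpy as np
from sklearn.datasets import make_moons

# 1000 noisy 2-d points on two interleaving half-circles
X, _ = make_moons(n_samples=1000, noise=0.05)
X = X.astype(np.float32)
```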
---
```
import numpy as np

batch_size, lr = 50, 1e-3
nb_epochs = 500
optimizer_G = torch.optim.Adam(net_G.parameters(), lr=lr)
optimizer_D = torch.optim.Adam(net_D.parameters(), lr=lr)
for e in range(nb_epochs):
    np.random.shuffle(X)
    real_samples = torch.from_numpy(X).type(torch.FloatTensor)
    for real_batch in real_samples.split(batch_size):
        # improving D (the last batch may be smaller, hence size(0))
        z = torch.empty(real_batch.size(0), z_dim).normal_()
        fake_batch = net_G(z)
        D_scores_on_real = net_D(real_batch)
        # detach: D's update should not backprop into G
        D_scores_on_fake = net_D(fake_batch.detach())
        loss = -torch.mean(torch.log(1 - D_scores_on_fake)
                           + torch.log(D_scores_on_real))
        optimizer_D.zero_grad()
        loss.backward()
        optimizer_D.step()
        # improving G with the non-saturating loss: -mean(log D(G(z)))
        z = torch.empty(batch_size, z_dim).normal_()
        fake_batch = net_G(z)
        D_scores_on_fake = net_D(fake_batch)
        loss = -torch.mean(torch.log(D_scores_on_fake))
        optimizer_G.zero_grad()
        loss.backward()
        optimizer_G.step()
```
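Once trained, sampling is a single forward pass with gradients disabled (continuing the code above):
```
with torch.no_grad():
    z = torch.empty(1000, z_dim).normal_()
    generated = net_G(z)  # 1000 fake 2-d points to compare with X
```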
---
# Loss curves
.center[
<img src="images/module10/losses.png" style="width: 600px;" />
]
Contrary to standard loss minimization, we have no guarantee here that training will stabilize: it can very well oscillate without converging.
---
# GAN in action
.center[
<img src="images/module10/GAN_action.gif" style="width: 600px;" />
]
A GAN fitting double moons.
---
# Mode collapse
The goal of a GAN is to find the best generator against any discriminator, i.e.
$$
\mathbf{G}^* = \arg\max_{\mathbf{G}} \min_{\mathbf{D}} \mathcal{L}(\mathbf{D},\mathbf{G}).
$$
But optimization alternates between finding the best generator against the current discriminator and finding the best discriminator against the current generator. In particular, we do not solve the $\max\min$ problem but alternate between the two problems.
--
As a result, a possible equilibrium of the game observed in practice is for the generator to produce only 'easy' samples, i.e. those that are the most difficult for the discriminator to classify.
---
# Mode collapse on MNIST
.center[
<img src="images/module10/fake_images-200.png" style="width: 250px;" />
]
These generated digits clearly do not follow the original distribution of the MNIST dataset; ones, for example, seem over-represented.
---
# Conditional GAN
.center[
<img src="images/module10/CGAN.png" style="width: 500px;" />
]
When labels are available, mode collapse can be mitigated (more details in practicals).
.footnote[Mirza et al. [Conditional Generative Adversarial Nets](https://arxiv.org/abs/1411.1784). 2014.]
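A minimal sketch of the conditioning trick (hypothetical sizes; the practicals give the full model): concatenate the latent code with a label embedding before the generator.
```
import torch
import torch.nn as nn

n_classes, z_dim, hidden_dim = 10, 32, 128
label_emb = nn.Embedding(n_classes, n_classes)
net_CG = nn.Sequential(nn.Linear(z_dim + n_classes, hidden_dim),
                       nn.ReLU(),
                       nn.Linear(hidden_dim, 2))

z = torch.randn(5, z_dim)
y = torch.randint(0, n_classes, (5,))
fake = net_CG(torch.cat([z, label_emb(y)], dim=1))  # samples conditioned on y
```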
---
# InfoGAN
.center[
<img src="images/module10/IGAN.png" style="width: 500px;" />
]
Even when labels are not available, mode collapse can be mitigated (more details in practicals).
.footnote[Chen et al. [InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets](https://arxiv.org/abs/1606.03657). 2016.]
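A simplified InfoGAN sketch (hypothetical sizes): the generator takes an extra code $c$, and an auxiliary head tries to recover $c$ from the generated sample, turning the mutual-information term into a cross-entropy.
```
import torch
import torch.nn as nn
import torch.nn.functional as F

n_codes, z_dim, hidden_dim = 10, 32, 128
net_G = nn.Sequential(nn.Linear(z_dim + n_codes, hidden_dim),
                      nn.ReLU(), nn.Linear(hidden_dim, 2))
trunk = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU())      # shared features
head_D = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())  # real/fake score
head_Q = nn.Linear(hidden_dim, n_codes)                         # code recovery

z = torch.randn(50, z_dim)
c = torch.randint(0, n_codes, (50,))
fake = net_G(torch.cat([z, F.one_hot(c, n_codes).float()], dim=1))
loss_info = F.cross_entropy(head_Q(trunk(fake)), c)  # added to the usual GAN losses
```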
---
# InfoGAN in action
.center[
<img src="images/module10/IGAN_action.gif" style="width: 700px;" />
]
An InfoGAN fitting double moons.
---
# Deep Convolutional GAN
"Historical attempts to scale up GANs using CNNs to model images have been unsuccessful. (...) We also encountered difficulties attempting to scale GANs using CNN architectures commonly used in the supervised literature. However, after extensive model exploration we identified a family of architectures that resulted in stable training across a range of datasets and allowed for training higher resolution and deeper generative models."
.footnote[Radford et al. [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434). 2015.]
---
# DC-GAN rules
- replace pooling layers with strided convolutions in $\mathbf{D}$ and strided transposed convolutions in $\mathbf{G}$.
- use batchnorm in both $\mathbf{D}$ and $\mathbf{G}$.
- remove fully connected hidden layers.
- use `ReLU` in $\mathbf{G}$ except for the output using `Tanh`.
- use `LeakyReLU` in $\mathbf{D}$ for all layers.
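Following these rules, a minimal sketch of a DCGAN pair for 64x64 RGB images (hypothetical channel sizes, in the spirit of the paper):
```
import torch.nn as nn

z_dim, ngf, ndf = 100, 64, 64
# generator: z of shape (N, z_dim, 1, 1) -> image of shape (N, 3, 64, 64)
net_G = nn.Sequential(
    nn.ConvTranspose2d(z_dim, ngf*8, 4, 1, 0, bias=False), nn.BatchNorm2d(ngf*8), nn.ReLU(True),
    nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf*4), nn.ReLU(True),
    nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf*2), nn.ReLU(True),
    nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf), nn.ReLU(True),
    nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False), nn.Tanh())
# discriminator: strided convolutions instead of pooling, LeakyReLU everywhere
net_D = nn.Sequential(
    nn.Conv2d(3, ndf, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, True),
    nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf*2), nn.LeakyReLU(0.2, True),
    nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf*4), nn.LeakyReLU(0.2, True),
    nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf*8), nn.LeakyReLU(0.2, True),
    nn.Conv2d(ndf*8, 1, 4, 1, 0, bias=False), nn.Sigmoid())
```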
---
# .grey[Generative Adversarial Networks (GAN)]
## [practicals: implementing conditional GAN and (simplified) InfoGAN](https://github.com/dataflowr/notebooks/blob/master/Module10/10_GAN_double_moon.ipynb)
---
class: end-slide, center
count: false
The end.
</textarea>
<script src="./assets/remark-latest.min.js"></script>
<script src="./assets/auto-render.min.js"></script>
<script src="./assets/katex.min.js"></script>
<script type="text/javascript">
function getParameterByName(name, url) {
if (!url) url = window.location.href;
name = name.replace(/[\[\]]/g, "\\$&");
var regex = new RegExp("[?&]" + name + "(=([^&#]*)|&|#|$)"),
results = regex.exec(url);
if (!results) return null;
if (!results[2]) return '';
return decodeURIComponent(results[2].replace(/\+/g, " "));
}
//var options = {sourceUrl: getParameterByName("p"),
// highlightLanguage: "python",
// // highlightStyle: "tomorrow",
// // highlightStyle: "default",
// highlightStyle: "github",
// // highlightStyle: "googlecode",
// // highlightStyle: "zenburn",
// highlightSpans: true,
// highlightLines: true,
// ratio: "16:9"};
var options = {sourceUrl: getParameterByName("p"),
highlightLanguage: "python",
highlightStyle: "github",
highlightSpans: true,
highlightLines: true,
ratio: "16:9",
slideNumberFormat: (current, total) => `
<div class="progress-bar-container">${current}/${total} <br/><br/>
<div class="progress-bar" style="width: ${current/total*100}%"></div>
</div>
`};
var renderMath = function() {
renderMathInElement(document.body, {delimiters: [ // mind the order of delimiters(!?)
{left: "$$", right: "$$", display: true},
{left: "$", right: "$", display: false},
{left: "\\[", right: "\\]", display: true},
{left: "\\(", right: "\\)", display: false},
]});
}
var slideshow = remark.create(options, renderMath);
</script>
</body>
</html>