-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprobabilistic_unet.py
346 lines (288 loc) · 16.1 KB
/
probabilistic_unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
#This code is based on: https://github.com/SimonKohl/probabilistic_unet
from unet_blocks import *
from unet import Unet
from utils import init_weights,init_weights_orthogonal_normal, l2_regularisation
import torch.nn.functional as F
from torch.distributions import Normal, Independent, kl
import param
# Select the GPU; this must match the device used by the training script.
device = param.device # device chosen in the param module (the GPU to use)
class Encoder(nn.Module):
    """
    Convolutional feature extractor made of len(num_filters) blocks.

    Block i contains no_convs_per_block 3x3 convolutions with num_filters[i] output channels,
    each followed by an in-place ReLU. Every block except the first is preceded by a 2x2
    average pooling (stride 2, ceil_mode=True).

    input_channels: number of channels of the input tensor
    num_filters: sequence with the output channel count of each block
    no_convs_per_block: number of convolutions in each block (at least one is always built)
    initializers: accepted for interface compatibility; not used here, the weights are
        always initialised with init_weights
    padding: if True the convolutions use padding=1, otherwise padding=0
    posterior: if True one extra input channel is expected, for the segmentation mask
        that is concatenated on the channel axis
    """

    def __init__(self, input_channels, num_filters, no_convs_per_block, initializers, padding=True, posterior=False):
        super(Encoder, self).__init__()
        # Kept for attribute/state compatibility; nothing is ever added to it.
        self.contracting_path = nn.ModuleList()
        # The posterior net sees the image plus the mask, hence one more channel.
        self.input_channels = input_channels + 1 if posterior else input_channels
        self.num_filters = num_filters

        pad = int(padding)
        modules = []
        in_ch = self.input_channels
        for level, out_ch in enumerate(num_filters):
            if level > 0:
                modules.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True))
            # The first convolution of a block maps in_ch -> out_ch ...
            modules += [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=pad), nn.ReLU(inplace=True)]
            # ... and the remaining ones keep out_ch channels.
            for _ in range(no_convs_per_block - 1):
                modules += [nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=pad), nn.ReLU(inplace=True)]
            in_ch = out_ch

        self.layers = nn.Sequential(*modules)
        self.layers.apply(init_weights)

    def forward(self, input):
        """Run the input through all blocks and return the final feature map."""
        return self.layers(input)
class AxisAlignedConvGaussian(nn.Module):
    """
    Convolutional net that parametrizes a Gaussian distribution with an axis-aligned
    (diagonal) covariance matrix.

    The input is encoded by an Encoder and averaged over both spatial axes. A 1x1
    convolution then maps it to 2 * latent_dim channels, which are split into the mean and
    the log standard deviation. For the posterior, the segmentation is concatenated to the
    input on the channel axis first.

    input_channels: channels of the image
    num_filters: filters per encoder block; num_filters[-1] feeds the 1x1 convolution
    no_convs_per_block: convolutions per encoder block
    latent_dim: dimension of the latent space
    initializers: forwarded to the Encoder
    posterior: if True, forward() expects a segmentation to concatenate to the input
    """

    def __init__(self, input_channels, num_filters, no_convs_per_block, latent_dim, initializers, posterior=False):
        super(AxisAlignedConvGaussian, self).__init__()
        self.input_channels = input_channels
        self.channel_axis = 1
        self.num_filters = num_filters
        self.no_convs_per_block = no_convs_per_block
        self.latent_dim = latent_dim
        self.posterior = posterior
        self.name = 'Posterior' if posterior else 'Prior'

        self.encoder = Encoder(input_channels, num_filters, no_convs_per_block, initializers, posterior=posterior)
        self.conv_layer = nn.Conv2d(num_filters[-1], 2 * latent_dim, (1, 1), stride=1)

        # Inspection slots overwritten by forward(); they hold 0 until the first call.
        self.show_img = self.show_seg = self.show_concat = self.show_enc = self.sum_input = 0

        nn.init.kaiming_normal_(self.conv_layer.weight, mode='fan_in', nonlinearity='relu')
        nn.init.normal_(self.conv_layer.bias)

    def forward(self, input, segm=None):
        """
        Return an Independent(Normal(mu, exp(log_sigma)), 1) distribution over the latent space.

        input: image batch; segm: optional mask, concatenated on dim 1 when given.
        """
        if segm is not None:
            # Remember the raw inputs, then append the mask as extra channel(s).
            self.show_img, self.show_seg = input, segm
            input = torch.cat((input, segm), dim=1)
        self.show_concat = input
        self.sum_input = torch.sum(input)

        features = self.encoder(input)
        self.show_enc = features

        # Average over height, then width, keeping a 1x1 spatial map for the 1x1 convolution.
        pooled = features.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)

        # Drop axis 2 twice (not a plain squeeze) so a batch of size 1 keeps its batch axis.
        stats = self.conv_layer(pooled).squeeze(2).squeeze(2)
        mu, log_sigma = stats[:, :self.latent_dim], stats[:, self.latent_dim:]

        # Diagonal multivariate normal: https://github.com/pytorch/pytorch/pull/11178
        return Independent(Normal(loc=mu, scale=torch.exp(log_sigma)), 1)
class Fcomb(nn.Module):
    """
    A function composed of 1x1 convolutions that combines the sample taken from the latent
    space with the output of the UNet (the feature map) by concatenating them along their
    channel axis.

    num_filters: UNet filter counts; the feature map is expected to have num_filters[0] channels
    latent_dim: dimension of the latent sample z
    num_output_channels: stored as self.num_channels (not used by the layers below)
    num_classes: output channels of the last 1x1 convolution
    no_convs_fcomb: total number of 1x1 convolutions; max(no_convs_fcomb, 2) are built
        (first, no_convs_fcomb - 2 middle ones, last)
    initializers: initializers['w'] == 'orthogonal' selects init_weights_orthogonal_normal,
        anything else selects init_weights
    use_tile: only the tiling variant is implemented; with use_tile=False no layers are built
        and forward() raises NotImplementedError
    """

    def __init__(self, num_filters, latent_dim, num_output_channels, num_classes, no_convs_fcomb, initializers, use_tile=True):
        super(Fcomb, self).__init__()
        self.num_channels = num_output_channels  # output channels
        self.num_classes = num_classes
        self.channel_axis = 1
        self.spatial_axes = [2, 3]
        self.num_filters = num_filters
        self.latent_dim = latent_dim
        self.use_tile = use_tile
        self.no_convs_fcomb = no_convs_fcomb
        self.name = 'Fcomb'

        if self.use_tile:
            layers = []
            # Decoder of N x a 1x1 convolution followed by a ReLU, except for the last layer.
            layers.append(nn.Conv2d(self.num_filters[0] + self.latent_dim, self.num_filters[0], kernel_size=1))
            layers.append(nn.ReLU(inplace=True))
            for _ in range(no_convs_fcomb - 2):
                layers.append(nn.Conv2d(self.num_filters[0], self.num_filters[0], kernel_size=1))
                layers.append(nn.ReLU(inplace=True))
            self.layers = nn.Sequential(*layers)
            self.last_layer = nn.Conv2d(self.num_filters[0], self.num_classes, kernel_size=1)

            if initializers['w'] == 'orthogonal':
                self.layers.apply(init_weights_orthogonal_normal)
                self.last_layer.apply(init_weights_orthogonal_normal)
            else:
                self.layers.apply(init_weights)
                self.last_layer.apply(init_weights)

    def tile(self, a, dim, n_tile):
        """
        Repeat every element of `a` n_tile times consecutively along `dim`.

        Equivalent to the repeat + index_select recipe from
        https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3. The result stays on
        a's device, so it does not depend on the module-level `device` and works wherever
        the inputs live.
        """
        return torch.repeat_interleave(a, n_tile, dim=dim)

    def forward(self, feature_map, z):
        """
        Combine the UNet features with a latent sample and return the logits.

        z is (batch, latent_dim) and feature_map is (batch, channels, H, W), as implied by the
        unsqueeze/tile calls below. z is broadcast to (batch, latent_dim, H, W), concatenated
        to the feature map on the channel axis, and passed through the 1x1 convolutions.
        """
        if not self.use_tile:
            # No layers are built in __init__ for this configuration.
            raise NotImplementedError("Fcomb only supports use_tile=True")

        z = torch.unsqueeze(z, 2)
        z = self.tile(z, 2, feature_map.shape[self.spatial_axes[0]])
        z = torch.unsqueeze(z, 3)
        z = self.tile(z, 3, feature_map.shape[self.spatial_axes[1]])

        # Concatenate the feature map (output of the UNet) and the sample from the latent space.
        feature_map = torch.cat((feature_map, z), dim=self.channel_axis)
        output = self.layers(feature_map)
        return self.last_layer(output)
class ProbabilisticUnet(nn.Module):
    """
    A probabilistic UNet (https://arxiv.org/abs/1806.05034) implementation.

    input_channels: the number of channels in the image (1 for greyscale and 3 for RGB)
    num_classes: the number of classes to predict
    num_filters: list with the number of filters per level; defaults to [32, 64, 128, 192]
        (a new list is created for every instance)
    latent_dim: dimension of the latent space
    no_convs_fcomb: number of 1x1 convolutions in the Fcomb head
    beta: weight of the KL term in elbo()
    The prior and posterior encoders use no_convs_per_block = 3 convolutions per block.
    All submodules are moved to the module-level `device` at construction time.
    """

    def __init__(self, input_channels=1, num_classes=1, num_filters=None, latent_dim=6, no_convs_fcomb=4, beta=10.0):
        super(ProbabilisticUnet, self).__init__()
        if num_filters is None:
            # Fresh list per instance; a list literal default would be shared by all instances.
            num_filters = [32, 64, 128, 192]
        self.input_channels = input_channels  # number of image channels
        self.num_classes = num_classes  # number of segmentation classes
        self.num_filters = num_filters  # filters per level
        self.latent_dim = latent_dim  # latent space dimension
        self.no_convs_per_block = 3
        self.no_convs_fcomb = no_convs_fcomb
        self.initializers = {'w': 'he_normal', 'b': 'normal'}  # initializer spec for submodules
        self.beta = beta
        self.z_prior_sample = 0

        self.unet = Unet(self.input_channels,
                         self.num_classes,
                         self.num_filters,
                         self.initializers,
                         apply_last_layer=False,
                         padding=True).to(device)
        self.prior = AxisAlignedConvGaussian(self.input_channels,
                                             self.num_filters,
                                             self.no_convs_per_block,
                                             self.latent_dim,
                                             self.initializers,).to(device)
        self.posterior = AxisAlignedConvGaussian(self.input_channels,
                                                 self.num_filters,
                                                 self.no_convs_per_block,
                                                 self.latent_dim,
                                                 self.initializers,
                                                 posterior=True).to(device)
        self.fcomb = Fcomb(self.num_filters,
                           self.latent_dim,
                           self.input_channels,
                           self.num_classes,
                           self.no_convs_fcomb,
                           {'w': 'orthogonal', 'b': 'normal'},
                           use_tile=True).to(device)

    def forward(self, patch, segm, training=True):
        """
        Construct the prior latent space for patch and run patch through the UNet.

        If training is True, also construct the posterior latent space from patch and segm.
        Results are stored on self (prior_latent_space, posterior_latent_space, unet_features);
        nothing is returned.
        """
        if training:
            self.posterior_latent_space = self.posterior.forward(patch, segm)
        self.prior_latent_space = self.prior.forward(patch)
        self.unet_features = self.unet.forward(patch, False)

    def sample(self, testing=False):
        """
        Sample a segmentation by decoding a prior sample combined with the UNet features.

        When testing is False a reparameterized sample (rsample) is drawn, so gradients can
        flow. Otherwise a plain sample() is drawn. The sample is stored in self.z_prior_sample.
        """
        if testing == False:
            z_prior = self.prior_latent_space.rsample()
            self.z_prior_sample = z_prior
        else:
            # You can choose whether you want a sample or the mean here.
            # For the GED it is important to take a sample.
            # z_prior = self.prior_latent_space.mean
            z_prior = self.prior_latent_space.sample()
            self.z_prior_sample = z_prior
        return self.fcomb.forward(self.unet_features, z_prior)

    def reconstruct(self, use_posterior_mean=False, calculate_posterior=False, z_posterior=None):
        """
        Reconstruct a segmentation by decoding a posterior sample together with the UNet features.

        use_posterior_mean: decode the posterior mean instead of a sample
        calculate_posterior: draw a new rsample from the posterior instead of using z_posterior
        z_posterior: the latent sample to decode when neither flag is set
        """
        if use_posterior_mean:
            # Independent has no `.loc`; `.mean` is the mean of the wrapped Normal.
            z_posterior = self.posterior_latent_space.mean
        elif calculate_posterior:
            z_posterior = self.posterior_latent_space.rsample()
        return self.fcomb.forward(self.unet_features, z_posterior)

    def kl_divergence(self, analytic=True, calculate_posterior=False, z_posterior=None):
        """
        Calculate the KL divergence between the posterior and the prior, KL(Q||P).

        analytic: calculate the KL analytically, or approximate it with a single posterior sample
        calculate_posterior: when sampling, draw the sample here instead of using z_posterior
        z_posterior: the sample used by the sampling approximation
        Returns one value per batch element.
        """
        if analytic:
            # Older PyTorch versions lacked a registered KL for Independent distributions,
            # see: https://github.com/pytorch/pytorch/issues/13545
            kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space)
        else:
            if calculate_posterior:
                z_posterior = self.posterior_latent_space.rsample()
            log_posterior_prob = self.posterior_latent_space.log_prob(z_posterior)
            log_prior_prob = self.prior_latent_space.log_prob(z_posterior)
            kl_div = log_posterior_prob - log_prior_prob
        return kl_div

    def elbo(self, segm, analytic_kl=True, reconstruct_posterior_mean=False):
        """
        Calculate the evidence lower bound of the log-likelihood of P(Y|X).

        Returns -(summed per-pixel BCE-with-logits + beta * batch-mean KL). The method also
        stores self.kl, self.reconstruction, self.reconstruction_loss and
        self.mean_reconstruction_loss.
        """
        # Element-wise loss: the non-deprecated spelling of size_average=False, reduce=False.
        criterion = nn.BCEWithLogitsLoss(reduction='none')
        z_posterior = self.posterior_latent_space.rsample()

        self.kl = torch.mean(self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior))

        # Reuse the posterior sample drawn above, so the KL and reconstruction terms agree.
        self.reconstruction = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean, calculate_posterior=False, z_posterior=z_posterior)

        reconstruction_loss = criterion(input=self.reconstruction, target=segm)
        self.reconstruction_loss = torch.sum(reconstruction_loss)
        self.mean_reconstruction_loss = torch.mean(reconstruction_loss)

        return -(self.reconstruction_loss + self.beta * self.kl)