-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_model.py
108 lines (95 loc) · 4.21 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from load_Brain_data import BrainS18Dataset
from probabilistic_unet import ProbabilisticUnet
from utils import l2_regularisation, label2multichannel
from save_load_net import save_model, load_model
from evaluate import evaluate
import param
# Hyperparameters
class_num = param.class_num # number of segmentation classes to use
epochs = 128 # number of training epochs
learning_rate = 1e-4 # optimizer learning rate
latent_dim = 6 # dimensionality of the latent space
train_batch_size = 16 # batch size for training
test_batch_size = 1 # batch size for evaluation/prediction
model_name = 'punet_e128_c9_ld6_f.pt' # filename for the saved model: epochs, patients, class_num, latent_dim
device = param.device # target device (CPU per param module — confirm if GPU intended)
# Log the training hyperparameters
print("类别数:{}\nEpoch:{}\nLearning_rate:{}\nlatent_dim:{}".format(class_num,
epochs,
learning_rate,
latent_dim))
print("待保存模型名称: {}".format(model_name))
# Dataset: BrainS18 MRI slices (T1 images + segmentation masks) from six patients
dataset = BrainS18Dataset(root_dir='data/BrainS18',
folders=['1_img', '4_img', '5_img', '7_img', '14_img', '148_img'],
class_num=class_num,
file_names=['_reg_T1.png', '_segm.png'])
# Index-based split with samplers (fixed train/test partition).
# NOTE(review): the split below is disabled — the whole dataset is used for training.
dataset_size = len(dataset) # total number of samples
# split = param.split
# indices = param.indices
# train_indices, test_indices = indices[split:], indices[:split] # train on all of the data above
train_indices = list(range(dataset_size))
train_sampler = SubsetRandomSampler(train_indices)
# test_sampler = SubsetRandomSampler(test_indices)
# Data loaders
train_loader = DataLoader(dataset, batch_size=train_batch_size, sampler=train_sampler) # training
train_eval_loader = DataLoader(dataset, batch_size=test_batch_size, sampler=train_sampler) # evaluation
# test_loader = DataLoader(dataset, batch_size=test_batch_size, sampler=test_sampler) # evaluation
# print("Number of training/test patches:", (len(train_indices),len(test_indices)))
print("Number of training patches: {}".format(len(train_indices)))
# Probabilistic U-Net over single-channel input slices.
net = ProbabilisticUnet(
    input_channels=1,
    num_classes=class_num,
    num_filters=[32, 64, 128, 192],
    latent_dim=latent_dim,
    no_convs_fcomb=4,
    beta=10.0,
)
net.to(device)

# Adam optimizer over all model parameters (no weight decay;
# L2 regularisation is added manually to the loss in the training loop).
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=0)
# Train the model and always save it on exit — including on Ctrl-C or an
# unexpected error — so partial training progress is never lost.
try:
    for epoch in range(epochs):
        print("Epoch {}".format(epoch))
        # --- training pass ---
        net.train()
        losses = 0.0  # running sum of per-step loss values (plain floats)
        for step, (patch, mask, _) in enumerate(train_loader):
            patch = patch.to(device)
            mask = mask.to(device)
            # mask = torch.unsqueeze(mask,1) (batch_size,240,240)->(batch_size,1,240,240)
            # Forward pass populates net.posterior / net.prior used below.
            net.forward(patch, mask, training=True)
            # Expand the label map from 1 channel (values 1..class_num)
            # to class_num binary channels (one-hot).
            mask = label2multichannel(mask.cpu(), class_num)
            mask = mask.to(device)
            elbo = net.elbo(mask)
            reg_loss = (l2_regularisation(net.posterior)
                        + l2_regularisation(net.prior)
                        + l2_regularisation(net.fcomb.layers))
            loss = -elbo + 1e-5 * reg_loss
            # BUG FIX: accumulate a detached Python float. The original
            # `losses += loss` summed graph-attached tensors, retaining every
            # step's autograd graph for the whole epoch (memory leak).
            losses += loss.item()
            if step % 10 == 0:
                print("-- [step {}] loss: {}".format(step, loss.item()))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Final step's loss for this epoch.
        print("-- [step {}] loss: {}".format(step, loss.item()))
        # --- evaluation on the training set ---
        losses /= (step + 1)  # mean loss over the epoch's steps
        print("Loss (Train): {}".format(losses))
        evaluate(net, train_eval_loader, device, class_num, test=False)
except KeyboardInterrupt as e:
    print('KeyboardInterrupt: {}'.format(e))
except Exception as e:
    # Deliberate broad catch: we still want to reach `finally` and save
    # whatever has been trained so far.
    print('Exception: {}'.format(e))
finally:
    # Persist the (possibly partially) trained model.
    print("saving the trained net model -- {}".format(model_name))
    save_model(net, path='model/{}'.format(model_name))