-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
53 lines (37 loc) · 1.6 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import sys
import hpconfig as cfg
import pathconfig
sys.path.append(pathconfig.sys_path)
from fastai.conv_learner import *
from model import get_learner
class EarlyStopping(Callback):
    """Checkpoint the learner whenever a tracked validation metric improves.

    Despite its name, this callback does not stop training: ``on_epoch_end``
    returns nothing and only saves the best weights seen so far.

    Args:
        savename: name passed to ``learn.save`` when a new best is reached.
        learn: the learner whose weights are saved.
        metric_idx: position of the tracked metric in the per-epoch
            ``metrics`` sequence. Defaults to 2, the slot this script treats
            as the dice score (presumably [val_loss, metric1, dice, ...] —
            confirm against the learner's metric list).
    """

    def __init__(self, savename, learn, metric_idx=2):
        # Starts at 0, so a checkpoint is written only once the metric is
        # strictly positive; if it never is, nothing is saved under savename.
        self.best_dice = 0
        self.savename = savename
        self.learn = learn
        self.metric_idx = metric_idx

    def on_epoch_end(self, metrics):
        """Save ``self.learn`` if ``metrics[self.metric_idx]`` beats the best so far."""
        score = metrics[self.metric_idx]
        if score > self.best_dice:
            self.best_dice = score
            print(f'\nNew highest dice achieved: {score}, saving to {self.savename}')
            self.learn.save(self.savename)
def _lr_group(lr, scaling):
    """Return per-layer-group learning rates ``[lr/scaling**2, lr/scaling, lr]``.

    Assumes the learner has three layer groups (TODO confirm against
    ``model.get_learner``).
    """
    return np.array([lr / (scaling ** 2), lr / scaling, lr])


def _checkpoint_name(stage):
    """Return the checkpoint name used by training ``stage``."""
    return cfg.save_name + str(stage)


def main():
    """Run the three-stage training schedule configured in ``hpconfig``.

    Each stage checkpoints its best weights via ``EarlyStopping`` under
    ``cfg.save_name + str(stage)``:

    * stage 0: ``freeze_to(1)``, rates from ``cfg.seq_lrs[0]``;
    * stage 1: ``unfreeze`` + ``bn_freeze(True)``, stage-0 rates / 4;
    * stage 2: rates from ``cfg.seq_lrs[1]``, no ``use_clr``.
    """
    cfg.print_hps()
    # Currently supports "resnet34" and "densenet121".
    learn = get_learner(cfg.arch)

    # Stage 0: partially frozen model (fastai freeze_to(1) semantics).
    learn.freeze_to(1)
    lr = cfg.seq_lrs[0]
    wd = cfg.seq_wds[0]
    lrs = _lr_group(lr, cfg.lrs_scalings[0])
    # Config-driven call; the checkpoint name must match the load below.
    learn.fit(lrs, 1, wds=wd, cycle_len=cfg.cycle_lens[0],
              use_clr=cfg.clrs[0],
              callbacks=[EarlyStopping(_checkpoint_name(0), learn)])
    learn.load(_checkpoint_name(0))

    # Stage 1: fine-tune all layer groups with batch-norm layers frozen.
    learn.unfreeze()
    learn.bn_freeze(True)
    learn.fit(lrs / 4, 1, wds=wd, cycle_len=cfg.cycle_lens[1],
              use_clr=cfg.clrs[1],
              callbacks=[EarlyStopping(_checkpoint_name(1), learn)])
    # NOTE(review): this reloads the *stage-0* checkpoint, so stage 2 does
    # not build on stage 1's result. If the stages are meant to chain, this
    # should probably be _checkpoint_name(1) — confirm intent.
    learn.load(_checkpoint_name(0))

    # Stage 2: second learning-rate / weight-decay setting.
    lr = cfg.seq_lrs[1]
    wd = cfg.seq_wds[1]
    lrs = _lr_group(lr, cfg.lrs_scalings[1])
    learn.fit(lrs, 1, wds=wd, cycle_len=cfg.cycle_lens[2],
              callbacks=[EarlyStopping(_checkpoint_name(2), learn)])
if __name__ == '__main__':
    # Script entry point: start training only when run directly, not when
    # this module is imported.
    main()