fine-tune.py
import os
import argparse
import logging
logging.basicConfig(level=logging.DEBUG)
from common import find_mxnet
from common import data, fit, modelzoo
import mxnet as mx
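
# Overview: this script loads a pre-trained checkpoint, swaps its last
# fully-connected layer for one sized to the target dataset, and trains the
# result with the shared fit/data helpers from common.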

def get_fine_tune_model(symbol, arg_params, num_classes, layer_name):
    """
    symbol: the pre-trained network symbol
    arg_params: the argument parameters of the pre-trained model
    num_classes: the number of classes in the fine-tune dataset
    layer_name: the name of the layer just before the last fully-connected layer
    """
    all_layers = symbol.get_internals()
    net = all_layers[layer_name + '_output']
    # attach a fresh fully-connected layer and softmax sized for the new dataset
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc')
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    # keep every pre-trained weight except those of the old 'fc' layer
    new_args = dict({k: arg_params[k] for k in arg_params if 'fc' not in k})
    return (net, new_args)
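
# Illustrative call (the __main__ block below does the equivalent with the
# parsed arguments; the class count 10 here is just a placeholder):
#   new_sym, new_args = get_fine_tune_model(sym, arg_params, 10, 'flatten0')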

def _save_model(args, rank=0):
    if args.model_prefix is None:
        return None
    dst_dir = os.path.dirname(args.model_prefix)
    # create the checkpoint directory if the prefix has one and it is missing
    if dst_dir and not os.path.isdir(dst_dir):
        os.mkdir(dst_dir)
    return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % (
        args.model_prefix, rank))
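
# Sketch (illustrative, not called in this script): the returned value is an
# epoch-end checkpoint callback, usable for example as
#   checkpoint = _save_model(args)
#   module.fit(..., epoch_end_callback=checkpoint)
# where `module` is a hypothetical mx.mod.Module.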

def _get_lr_scheduler(args, kv):
    if 'lr_factor' not in args or args.lr_factor >= 1:
        return (args.lr, None)
    # batches per epoch; integer division keeps the scheduler steps integral
    epoch_size = args.num_examples // args.batch_size
    if 'dist' in args.kv_store:
        epoch_size //= kv.num_workers
    begin_epoch = args.load_epoch if args.load_epoch else 0
    step_epochs = [int(l) for l in args.lr_step_epochs.split(',')]
    # when resuming, apply the decays that would already have happened
    lr = args.lr
    for s in step_epochs:
        if begin_epoch >= s:
            lr *= args.lr_factor
    if lr != args.lr:
        logging.info('Adjust learning rate to %e for epoch %d', lr, begin_epoch)
    steps = [epoch_size * (x - begin_epoch) for x in step_epochs if x - begin_epoch > 0]
    return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor))
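
# Worked example (with the defaults set below, lr=0.01 and
# lr_step_epochs='20,30,40,50', and a hypothetical lr_factor=0.1): a fresh run
# (load_epoch unset) keeps lr=0.01 and schedules drops to 1e-3, 1e-4, 1e-5 and
# 1e-6 after epochs 20, 30, 40 and 50. Resuming at load_epoch=35 instead
# returns lr=1e-4 (two decays already applied) and schedules the remaining
# drops 5 and 15 epochs into the resumed run.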

if __name__ == "__main__":
    # parse args
    parser = argparse.ArgumentParser(description="fine-tune a dataset",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    train = fit.add_fit_args(parser)
    data.add_data_args(parser)
    aug = data.add_data_aug_args(parser)
    parser.add_argument('--pretrained-model', type=str,
                        help='the pre-trained model')
    parser.add_argument('--layer-before-fullc', type=str, default='flatten0',
                        help='the name of the layer before the last fully-connected layer')
    # use less augmentation for fine-tuning
    data.set_data_aug_level(parser, 3)
    # use a small learning rate and less regularization
    parser.set_defaults(
        # network
        network          = 'inception-bn',
        # data
        num_classes      = 1000,
        num_examples     = 1281167,
        image_shape      = '3,224,224',
        min_random_scale = 1,  # if the input image has min size k, suggest 256.0/k,
                               # e.g. 0.533 for 480
        # train
        num_epochs       = 60,
        lr_step_epochs   = '20,30,40,50',
        lr               = 0.01,
        batch_size       = 32,
        optimizer        = 'sgd',
        disp_batches     = 10,
        top_k            = 5,
        data_train       = '/data/imagenet1k/imagenet1k-train',
        data_val         = '/data/imagenet1k/imagenet1k-val'
    )
    args = parser.parse_args()

    # load the pre-trained model; epoch 126 matches the inception-bn checkpoint
    if not args.pretrained_model:
        parser.error('--pretrained-model is required')
    sym, arg_params, aux_params = mx.model.load_checkpoint(args.pretrained_model, 126)

    # replace the last fully-connected layer so the output matches num_classes
    (new_sym, new_args) = get_fine_tune_model(
        sym, arg_params, args.num_classes, args.layer_before_fullc)

    # train
    fit.fit(args        = args,
            network     = new_sym,
            data_loader = data.get_rec_iter,
            arg_params  = new_args,
            aux_params  = aux_params)
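
# Example invocation (illustrative; the record-file paths are placeholders and
# flags such as --data-train, --num-classes and --gpus come from the shared
# fit/data argument helpers). The prefix must point at an epoch-126 checkpoint,
# e.g. model/Inception-BN-0126.params:
#   python fine-tune.py --pretrained-model model/Inception-BN \
#       --data-train data/train.rec --data-val data/val.rec \
#       --num-classes 10 --num-examples 50000 --gpus 0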