-
Notifications
You must be signed in to change notification settings - Fork 68
/
Copy pathmain.py
executable file
·125 lines (97 loc) · 5.42 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from SPADE import SPADE
import argparse
from utils import *
"""parsing and configuration"""
def parse_args():
    """Declare every command-line option, parse the command line and validate it.

    Returns:
        The value returned by ``check_args`` for the parsed namespace.
    """
    cli = argparse.ArgumentParser(description="Tensorflow implementation of SPADE")
    add = cli.add_argument  # short alias; options are registered in the original order

    # Run mode and dataset
    add('--phase', type=str, default='train',
        choices=('train', 'guide', 'random'), help='phase name')
    add('--dataset', type=str, default='spade_celebA', help='dataset_name')

    # Training schedule
    add('--epoch', type=int, default=100, help='The number of epochs to run')
    add('--iteration', type=int, default=10000, help='The number of training iterations')
    # The total number of iterations is [epoch * iteration]
    add('--batch_size', type=int, default=1, help='The size of batch size')
    add('--print_freq', type=int, default=1000, help='The number of image_print_freq')
    add('--save_freq', type=int, default=50000, help='The number of ckpt_save_freq')
    add('--decay_flag', type=str2bool, default=True, help='The decay_flag')
    add('--decay_epoch', type=int, default=50, help='decay epoch')

    add('--lr', type=float, default=0.0002, help='The learning rate')
    add('--TTUR', type=str2bool, default=True, help='Use TTUR training scheme')

    # Style sampling / guided translation
    add('--num_style', type=int, default=3, help='number of styles to sample')
    add('--guide_img', type=str, default='guide.jpg', help='Style guided image translation')

    # Loss weights
    add('--ld', type=float, default=10.0, help='The gradient penalty lambda')
    add('--adv_weight', type=int, default=1, help='Weight about GAN')
    add('--vgg_weight', type=int, default=10, help='Weight about perceptual loss')
    add('--feature_weight', type=int, default=10,
        help='Weight about discriminator feature matching loss')
    add('--kl_weight', type=float, default=0.05, help='Weight about kl-divergence')

    # Network and optimizer settings
    add('--gan_type', type=str, default='hinge',
        help='gan / lsgan / hinge / wgan-gp / wgan-lp / dragan')
    add('--ch', type=int, default=64, help='base channel number per layer')
    add('--n_dis', type=int, default=4, help='The number of discriminator layer')
    add('--n_scale', type=int, default=2, help='number of scales')
    add('--n_critic', type=int, default=1, help='The number of critic')
    add('--sn', type=str2bool, default=True, help='using spectral norm')
    add('--beta1', type=float, default=0.5, help='momentum term of adam')
    add('--beta2', type=float, default=0.999, help='momentum term of adam')
    add('--num_upsampling_layers', type=str, default='more',
        choices=('normal', 'more', 'most'),
        help="If 'more', adds upsampling layer between the two middle resnet blocks. "
             "If 'most', also add one more upsampling + resnet layer at the end of the generator")

    # Image / segmentation-map dimensions
    add('--img_height', type=int, default=256, help='The height size of image')
    add('--img_width', type=int, default=256, help='The width size of image ')
    add('--img_ch', type=int, default=3, help='The size of image channel')
    add('--segmap_ch', type=int, default=3, help='The size of segmap channel')
    add('--augment_flag', type=str2bool, default=True, help='Image augmentation use or not')

    # Output directories
    add('--checkpoint_dir', type=str, default='checkpoint',
        help='Directory name to save the checkpoints')
    add('--result_dir', type=str, default='results',
        help='Directory name to save the generated images')
    add('--log_dir', type=str, default='logs',
        help='Directory name to save training logs')
    add('--sample_dir', type=str, default='samples',
        help='Directory name to save the samples on training')

    parsed = cli.parse_args()
    return check_args(parsed)
"""checking arguments"""
def _is_at_least_one(value):
    """Return whether ``value >= 1``; a value that cannot be compared counts as invalid."""
    try:
        return value >= 1
    except TypeError:
        return False


def check_args(args):
    """Run ``check_folder`` on the output directories and sanity-check numeric options.

    Validation is best-effort: an out-of-range ``epoch`` or ``batch_size`` only
    prints a warning, and ``args`` is returned either way.

    The checks use an explicit comparison rather than ``assert`` inside a bare
    ``except``: assertions are removed when Python runs with ``-O`` (which would
    silently disable the warnings), and a bare ``except`` also intercepts
    ``KeyboardInterrupt`` and ``SystemExit``.

    Args:
        args: Parsed option namespace; must expose ``checkpoint_dir``,
            ``result_dir``, ``log_dir``, ``sample_dir``, ``epoch`` and
            ``batch_size``.

    Returns:
        The same ``args`` object, unmodified by this function.
    """
    # --checkpoint_dir, --result_dir, --log_dir, --sample_dir
    # NOTE(review): check_folder is not defined in this file (presumably from
    # the star-import of utils) -- confirm it creates missing directories.
    for directory in (args.checkpoint_dir, args.result_dir, args.log_dir, args.sample_dir):
        check_folder(directory)

    # --epoch
    if not _is_at_least_one(args.epoch):
        print('number of epochs must be larger than or equal to one')

    # --batch_size
    if not _is_at_least_one(args.batch_size):
        print('batch size must be larger than or equal to one')

    return args
"""main"""
def main():
    """Parse the command line, build the SPADE model and run the selected phase.

    ``args.phase`` selects exactly one action: ``'train'`` calls ``train()``,
    ``'random'`` calls ``random_test()`` and ``'guide'`` calls ``guide_test()``.
    """
    # Function-scope import so this block does not depend on the file header.
    import sys

    # parse arguments
    args = parse_args()
    if args is None:
        # sys.exit() rather than the site-provided exit() helper, which is
        # intended for the interactive shell and is unavailable when the
        # site module is not loaded (e.g. `python -S`).
        sys.exit()

    # open session
    # NOTE(review): `tf` is not imported explicitly in this file; presumably it
    # arrives through the star-import of utils -- confirm.
    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as sess:
        gan = SPADE(sess, args)

        # build graph
        gan.build_model()

        # show network architecture
        show_all_variables()

        # The phases are mutually exclusive, so a single if/elif chain suffices.
        if args.phase == 'train':
            gan.train()
            print(" [*] Training finished!")
        elif args.phase == 'random':
            gan.random_test()
            print(" [*] Random test finished!")
        elif args.phase == 'guide':
            gan.guide_test()
            print(" [*] Guide test finished")
# Start the program only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()