-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathArgHandler.py
248 lines (211 loc) · 9.38 KB
/
ArgHandler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import time
import torch
from torch import nn
import CustomExceptions
from Nets import Utils
from Nets.ResGan import ResNetGenerator, ResNetDiscriminator
from Nets.SmallGan import Small_GAN
def handle_pretrained_generator(**kwargs):
    """Interpret the optional ``pretrained_generator`` keyword as a boolean.

    The value may be a real ``bool`` or a textual flag. The spellings
    ``true``, ``t``, ``yes``, ``y`` and ``1`` (case-insensitive) mean
    ``True``; everything else means ``False``.

    Returns:
        bool: The parsed flag. ``False`` is returned when the keyword is
        missing or its value is not a recognised "true" spelling, so the
        caller always receives a ``bool`` and never an implicit ``None``.
    """
    value = kwargs.get('pretrained_generator', False)
    if isinstance(value, bool):
        return value
    # str() lets non-string flags such as the integer 1 be parsed instead of
    # failing on a missing .lower() method.
    return str(value).strip().lower() in ('true', 't', 'yes', 'y', '1')
def handle_pretrained_encoder(**kwargs):
    """Interpret the optional ``pretrained_encoder`` keyword as a boolean.

    The value may be a real ``bool`` or a textual flag. The spellings
    ``true``, ``t``, ``yes``, ``y`` and ``1`` (case-insensitive) mean
    ``True``; everything else means ``False``.

    Returns:
        bool: The parsed flag. ``False`` is returned when the keyword is
        missing or its value is not a recognised "true" spelling, so the
        caller always receives a ``bool`` and never an implicit ``None``.
    """
    value = kwargs.get('pretrained_encoder', False)
    if isinstance(value, bool):
        return value
    # str() lets non-string flags such as the integer 1 be parsed instead of
    # failing on a missing .lower() method.
    return str(value).strip().lower() in ('true', 't', 'yes', 'y', '1')
def handle_noise_size(**kwargs):
    """Read and validate the mandatory ``noise_size`` keyword.

    Returns:
        int: The noise size, guaranteed to be greater than zero.

    Raises:
        CustomExceptions.InvalidNoiseSizeError: If ``noise_size`` is missing,
            cannot be converted with ``int()``, or is not greater than zero.
    """
    if 'noise_size' not in kwargs:
        raise CustomExceptions.InvalidNoiseSizeError("You have to set the noise_size argument")
    # Only the conversion sits inside the try block, so the range check below
    # can never be re-caught and have its message replaced.
    try:
        noise_size = int(kwargs['noise_size'])
    except (TypeError, ValueError) as err:
        # TypeError covers values such as None, which int() rejects without
        # raising ValueError.
        raise CustomExceptions.InvalidNoiseSizeError("noise_size must be a positive integer") from err
    if noise_size <= 0:
        raise CustomExceptions.InvalidNoiseSizeError("noise_size must be greater than 0.")
    return noise_size
def handle_num_epochs(**kwargs):
    """Read and validate the mandatory ``num_epochs`` keyword.

    Returns:
        int: The number of epochs, guaranteed to be greater than zero.

    Raises:
        CustomExceptions.NumEpochsError: If ``num_epochs`` is missing, cannot
            be converted with ``int()``, or is not greater than zero.
    """
    if 'num_epochs' not in kwargs:
        raise CustomExceptions.NumEpochsError("The Number of epochs must be defined")
    # Only the conversion sits inside the try block, so the range check below
    # can never be re-caught and have its message replaced.
    try:
        num_epochs = int(kwargs['num_epochs'])
    except (TypeError, ValueError) as err:
        # TypeError covers values such as None, which int() rejects without
        # raising ValueError.
        raise CustomExceptions.NumEpochsError("The Number of epochs must be a positive integer") from err
    if num_epochs <= 0:
        raise CustomExceptions.NumEpochsError("The Number of epochs must be greater than 0")
    return num_epochs
def handle_batch_size(**kwargs):
    """Read and validate the mandatory ``batch_size`` keyword.

    Returns:
        int: The batch size, guaranteed to be greater than zero.

    Raises:
        CustomExceptions.BatchSizeError: If ``batch_size`` is missing, cannot
            be converted with ``int()``, or is not greater than zero.
    """
    if 'batch_size' not in kwargs:
        raise CustomExceptions.BatchSizeError("The batch size must be defined")
    # Only the conversion sits inside the try block, so the range check below
    # can never be re-caught and have its message replaced.
    try:
        batch_size = int(kwargs['batch_size'])
    except (TypeError, ValueError) as err:
        # TypeError covers values such as None, which int() rejects without
        # raising ValueError.
        raise CustomExceptions.BatchSizeError("The batch size must be a positive integer") from err
    if batch_size <= 0:
        raise CustomExceptions.BatchSizeError("The batch size must be greater than 0!")
    return batch_size
def handle_learning_rate(**kwargs):
    """Read the mandatory ``learning_rate`` keyword as a float.

    Returns:
        float: The value of ``learning_rate`` converted with ``float()``.

    Raises:
        CustomExceptions.LearningRateError: If ``learning_rate`` is missing or
            cannot be converted with ``float()``.
    """
    if 'learning_rate' not in kwargs:
        raise CustomExceptions.LearningRateError(
            "The learning rate must be defined. Use 'learning_rate=0.0001' for example")
    try:
        return float(kwargs['learning_rate'])
    except (TypeError, ValueError) as err:
        # TypeError covers values such as None, which float() rejects without
        # raising ValueError.
        raise CustomExceptions.LearningRateError("The learning rate must be float") from err
def handle_criterion(**kwargs):
    """Build the loss object selected by the mandatory ``criterion`` keyword.

    Only ``'BCELoss'`` is currently implemented. ``'Wasserstein'`` and
    ``'MiniMax'`` are recognised names but have no implementation yet.

    Returns:
        torch.nn.BCELoss: A new loss instance when ``criterion == 'BCELoss'``.

    Raises:
        NotImplementedError: If ``criterion`` is ``'Wasserstein'`` or
            ``'MiniMax'``.
        CustomExceptions.InvalidLossError: If ``criterion`` is missing or is
            any other value.
    """
    if 'criterion' not in kwargs:
        raise CustomExceptions.InvalidLossError()
    requested = kwargs['criterion']
    if requested == 'BCELoss':
        return nn.BCELoss()
    if requested == 'Wasserstein' or requested == 'MiniMax':
        # TODO: implement these criteria; until then fail with a clear message.
        raise NotImplementedError(f'The criterion "{requested}" is not implemented yet')
    raise CustomExceptions.InvalidLossError()
def handle_real_img_fake_label(**kwargs):
    """Parse the optional ``real_img_fake_label`` keyword into a boolean.

    A ``bool`` value is returned unchanged. Otherwise the value is lowered
    with ``.lower()`` and matched against the accepted true/false spellings.
    When the keyword is absent a notice is printed and ``False`` is used.

    Returns:
        bool: The parsed flag.

    Raises:
        CustomExceptions.InvalidArgumentError: If the value is not a ``bool``
            and is not one of the accepted spellings.
    """
    fallback = False
    if 'real_img_fake_label' not in kwargs:
        print(f'No argument for "real_img_fake_label" found. Will use {fallback}')
        return fallback

    raw = kwargs['real_img_fake_label']
    if isinstance(raw, bool):
        return raw

    lowered = raw.lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise CustomExceptions.InvalidArgumentError(
        f'Invalid argument for "real_img_fake_label": "{raw}"')
def handle_snapshot_settings(**kwargs):
    """Read the optional ``snapshot_interval`` keyword.

    When the keyword is absent, snapshots are disabled and a notice is
    printed.

    Returns:
        tuple[int, bool]: ``(snapshot_interval, do_snapshots)``. This is
        ``(-1, False)`` when no interval was given, otherwise the validated
        interval and ``True``.

    Raises:
        CustomExceptions.InvalidSnapshotInterval: If ``snapshot_interval``
            cannot be converted with ``int()`` or is not greater than zero.
    """
    if 'snapshot_interval' not in kwargs:
        print('snapshot_interval is not defined. Will not create nor store snapshots')
        return -1, False
    try:
        snapshot_interval = int(kwargs['snapshot_interval'])
    except (TypeError, ValueError) as err:
        # Raise the snapshot-specific error here; a noise-size error would
        # point the user at the wrong argument.
        raise CustomExceptions.InvalidSnapshotInterval(
            "snapshot_interval must be a positive integer") from err
    if snapshot_interval <= 0:
        raise CustomExceptions.InvalidSnapshotInterval("snapshot_interval must be greater than 0.")
    return snapshot_interval, True
def handle_name(**kwargs):
    """Derive the run name and a time-stamped unique name.

    With a ``name`` keyword, the unique name is ``<name>_<timestamp>``.
    Without it, a notice is printed, the name is the placeholder ``'#'`` and
    the unique name is the timestamp alone.

    Returns:
        tuple: ``(name, unique_name)``.

    Raises:
        CustomExceptions.InvalidNameError: If the given name fails
            ``str.isalnum()``.
    """
    stamp_format = "%Y-%m-%d_%H-%M-%S"

    if 'name' not in kwargs:
        stamp_only = time.strftime(stamp_format)
        print("No name given. Will only use the time-stamp")
        return '#', stamp_only

    given = kwargs['name']
    stamped = f'{given}_{time.strftime(stamp_format)}'
    if not given.isalnum():
        raise CustomExceptions.InvalidNameError("name must only contain alphanumeric characters")
    return given, stamped
def handle_device(**kwargs):
    """Select the torch device requested by the optional ``device`` keyword.

    ``device == "GPU"`` selects CUDA. Any other value, or a missing keyword,
    selects the CPU, so this function always returns a ``torch.device``
    rather than ``None``.

    Returns:
        torch.device: ``torch.device('cuda')`` or ``torch.device('cpu')``.

    Raises:
        CustomExceptions.GpuNotFoundError: If ``"GPU"`` is requested but
            ``torch.cuda.is_available()`` is false.
    """
    if kwargs.get('device') == "GPU":
        if not torch.cuda.is_available():
            raise CustomExceptions.GpuNotFoundError("Cannot find a CUDA device")
        return torch.device('cuda')
    return torch.device('cpu')
def handle_generator(NUM_CLASSES, N_IMAGE_CHANNELS, **kwargs):
    """Instantiate the generator network named by the ``generator`` keyword.

    The device and noise size are resolved first through ``handle_device``
    and ``handle_noise_size``, so their errors surface before the generator
    name is inspected. The constructed network is moved with ``.to(device)``.

    Supported names: ``"small_gan"``, ``"res_net_depth1"``,
    ``"res_net_depth2"``. The ResNet variants receive
    ``noise_size + NUM_CLASSES`` as their first argument and
    ``N_IMAGE_CHANNELS`` as their second.

    Raises:
        CustomExceptions.NoGeneratorError: If ``generator`` is missing or
            names an unknown network.
    """
    device = handle_device(**kwargs)
    noise_size = handle_noise_size(**kwargs)

    if 'generator' not in kwargs:
        raise CustomExceptions.NoGeneratorError("You need to define the generator net. keyword: 'generator'")

    choice = kwargs['generator']
    if choice == "small_gan":
        net = Small_GAN.GeneratorNet(
            noise_size=noise_size,
            num_classes=NUM_CLASSES,
            n_image_channels=N_IMAGE_CHANNELS,
        )
    elif choice == "res_net_depth1":
        net = ResNetGenerator.resnetGeneratorDepth1(noise_size + NUM_CLASSES, N_IMAGE_CHANNELS)
    elif choice == "res_net_depth2":
        net = ResNetGenerator.resnetGeneratorDepth2(noise_size + NUM_CLASSES, N_IMAGE_CHANNELS)
    else:
        raise CustomExceptions.NoGeneratorError(
            f'The given generator net "{choice}" cannot be found')
    return net.to(device)
def handle_discriminator(NUM_CLASSES, N_IMAGE_CHANNELS, **kwargs):
    """Instantiate the discriminator network named by ``discriminator``.

    The device is resolved first through ``handle_device``, so its errors
    surface before the discriminator name is inspected. The constructed
    network is moved with ``.to(device)``.

    Supported names: ``"small_gan"``, ``"res_net_depth1"``,
    ``"res_net_depth1_leaky"``, ``"res_net_depth2"``. The ResNet variants
    receive ``N_IMAGE_CHANNELS + NUM_CLASSES`` as their first argument and
    ``1`` as their second.

    Raises:
        CustomExceptions.NoDiscriminatorError: If ``discriminator`` is
            missing or names an unknown network.
    """
    device = handle_device(**kwargs)

    if 'discriminator' not in kwargs:
        raise CustomExceptions.NoDiscriminatorError(
            "You need to define the discriminator net. keyword: 'discriminator'")

    choice = kwargs['discriminator']
    if choice == "small_gan":
        net = Small_GAN.DiscriminatorNet(n_image_channels=N_IMAGE_CHANNELS, num_classes=NUM_CLASSES)
    elif choice == "res_net_depth1":
        net = ResNetDiscriminator.resnetDiscriminatorDepth1(N_IMAGE_CHANNELS + NUM_CLASSES, 1)
    elif choice == "res_net_depth1_leaky":
        net = ResNetDiscriminator.resnetDiscriminatorDepth1Leaky(N_IMAGE_CHANNELS + NUM_CLASSES, 1)
    elif choice == "res_net_depth2":
        net = ResNetDiscriminator.resnetDiscriminatorDepth2(N_IMAGE_CHANNELS + NUM_CLASSES, 1)
    else:
        raise CustomExceptions.NoDiscriminatorError(
            f'The given discriminator net "{choice}" cannot be found')
    return net.to(device)
def handle_model_path(**kwargs):
    """Return the mandatory ``model_path`` keyword unchanged.

    Raises:
        NotImplementedError: If ``model_path`` was not supplied; there is
            deliberately no default.
    """
    if 'model_path' not in kwargs:
        raise NotImplementedError('Using a default model_path is not implemented. And will never be')
    return kwargs['model_path']
def handle_output_path(**kwargs):
    """Return the mandatory ``output_path`` keyword unchanged.

    Raises:
        NotImplementedError: If ``output_path`` was not supplied; there is
            deliberately no default.
    """
    if 'output_path' not in kwargs:
        raise NotImplementedError('Using a default output_path is not implemented. And will never be')
    return kwargs['output_path']
def handle_weights_init(**kwargs):
    """Pick the weight-initialisation function named by ``weights_init``.

    ``"xavier"`` selects ``Utils.weights_init_xavier`` and ``"normal"``
    selects ``Utils.weights_init``. A missing keyword or any other value
    falls back to ``Utils.weights_init`` and prints a notice.

    Returns:
        The selected attribute of ``Utils``; it is returned, not called.
    """
    requested = kwargs['weights_init'] if 'weights_init' in kwargs else None
    if requested == "normal":
        return Utils.weights_init
    if requested == "xavier":
        return Utils.weights_init_xavier
    # Both the missing-keyword case and an unknown value share this fallback.
    print("Default normal weight init is used")
    return Utils.weights_init
def handle_augmentation(**kwargs):
    """Interpret the optional ``augmentation`` keyword as a boolean.

    The value may be a real ``bool`` or a textual flag: ``true/t/yes/y/1``
    or ``false/f/no/n/0``, case-insensitive.

    Returns:
        bool: The parsed flag. When the keyword is missing or its value is
        not recognised, a notice is printed and ``False`` is returned, so
        the caller always receives a ``bool`` and never an implicit ``None``.
    """
    value = kwargs.get('augmentation')
    if isinstance(value, bool):
        return value
    if value is not None:
        # str() lets non-string flags such as the integer 1 be parsed instead
        # of failing on a missing .lower() method.
        flag = str(value).strip().lower()
        if flag in ('true', 't', 'yes', 'y', '1'):
            return True
        if flag in ('false', 'f', 'no', 'n', '0'):
            return False
    print("Won't use data augmentation.")
    return False