-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
346 lines (275 loc) · 14.9 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import itertools
import os
import time
import argparse
import json
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DistributedSampler, DataLoader
import torch.multiprocessing as mp
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel
from env import AttrDict, build_env
from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
from models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, MultiResSpecDiscriminator,\
CoMBD, SBD, discriminator_loss, feature_loss, generator_loss
from pqmf import PQMF
from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
from stft import TorchSTFT
# Enable cuDNN's benchmark mode (cuDNN auto-tunes and caches convolution
# algorithms); set once at import time so every spawned worker inherits it.
torch.backends.cudnn.benchmark = True
def get_n_params(model):
    """Return the total number of scalar elements over ``model.parameters()``.

    Every parameter yielded by ``model.parameters()`` is counted, whether or
    not it requires gradients.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        int: Sum of the element counts of all parameters; 0 if there are none.
    """
    # Tensor.numel() is the product of the tensor's dimensions, which is what
    # the previous hand-written nested loop computed (1 for a 0-dim tensor).
    return sum(param.numel() for param in model.parameters())
def _save_checkpoints(a, h, generator, combd, sbd, msd, optim_g, optim_d, steps, epoch):
    """Write the generator checkpoint and the discriminator/optimizer checkpoint for ``steps``."""

    def unwrap(model):
        # Under multi-GPU training the models are DistributedDataParallel
        # wrappers; the state dict is taken from the wrapped `.module`.
        return model.module if h.num_gpus > 1 else model

    checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
    save_checkpoint(checkpoint_path, {'generator': unwrap(generator).state_dict()})

    checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
    # Best-effort removal of the discriminator/optimizer checkpoint written
    # 10000 steps earlier. A missing file is expected (first checkpoint, or a
    # different interval), so only filesystem errors are swallowed.
    # NOTE(review): 10000 matches the default --checkpoint_interval; confirm
    # whether this should follow a.checkpoint_interval instead.
    try:
        os.remove("{}/do_{:08d}".format(a.checkpoint_path, steps - 10000))
    except OSError:
        print("continue_training")
    save_checkpoint(checkpoint_path,
                    {'combd': unwrap(combd).state_dict(),
                     'sbd': unwrap(sbd).state_dict(),
                     'msd': unwrap(msd).state_dict(),
                     'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
                     'epoch': epoch})


def _run_validation(generator, stft, validation_loader, sw, h, device, steps):
    """Run the generator over the validation set and log mel error, audio and figures."""
    generator.eval()
    torch.cuda.empty_cache()
    val_err_tot = 0
    n_batches = 0
    with torch.no_grad():
        for j, batch in enumerate(validation_loader):
            x, y, _, _, y_mel = batch
            # The length from the batch is not used: the generator is given
            # the size of x along dim 2 instead.
            length = torch.LongTensor([x.size(2)]).to(device)
            spec, phase = generator(x.to(device), length)
            y_g_hat = stft.inverse(spec, phase)
            y_mel = y_mel.to(device, non_blocking=True)
            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
                                          h.hop_size, h.win_size,
                                          h.fmin, h.fmax_for_loss)
            val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
            n_batches = j + 1

            # Only the first five validation items get audio/figure summaries.
            if j <= 4:
                if steps == 0:
                    # Ground truth never changes, so it is logged once.
                    sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
                    sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)

                sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
                y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
                                             h.sampling_rate, h.hop_size, h.win_size,
                                             h.fmin, h.fmax)
                sw.add_figure('generated/y_hat_spec_{}'.format(j),
                              plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)

        # Guard against an empty validation set (previously a NameError on `j`).
        if n_batches:
            val_err = val_err_tot / n_batches
            sw.add_scalar("validation/mel_spec_error", val_err, steps)

    generator.train()


def train(rank, a, h):
    """Train the generator against the CoMBD, SBD and MSD/MRSD discriminators.

    Resumes from the latest ``g_*`` / ``do_*`` checkpoints in
    ``a.checkpoint_path`` when both exist. Only rank 0 prints, writes
    checkpoints, logs to TensorBoard and runs validation. Both learning-rate
    schedulers are stepped once per epoch.

    Args:
        rank: Process index; also used as the CUDA device index.
        a: Parsed command-line arguments (paths, epoch count, logging,
            checkpoint and validation intervals, fine-tuning flag).
        h: Hyper-parameter ``AttrDict`` built from the JSON config.
    """
    if h.num_gpus > 1:
        init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
                           world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)

    torch.cuda.manual_seed(h.seed)
    device = torch.device('cuda:{:d}'.format(rank))

    subbands2, taps2, cutoff_ratio2, beta2 = h.pqmf_config["lv2"]
    # Fix: the level-1 filter bank used to be built from the "lv2" entry as
    # well (copy-paste), making both banks identical. Configs without an
    # "lv1" entry fall back to the previous behaviour.
    subbands1, taps1, cutoff_ratio1, beta1 = h.pqmf_config.get("lv1", h.pqmf_config["lv2"])
    pqmf_lv2 = PQMF(subbands2, taps2, cutoff_ratio2, beta2)
    pqmf_lv1 = PQMF(subbands1, taps1, cutoff_ratio1, beta1)

    generator = Generator(h).to(device)
    combd = CoMBD(h.combd, [pqmf_lv2, pqmf_lv1]).to(device)
    sbd = SBD(h["sbd"]).to(device)
    if h.use_mrsd:
        print("Use mrsd...")
        msd = MultiResSpecDiscriminator().to(device)
    else:
        msd = MultiScaleDiscriminator().to(device)
    stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size,
                     win_length=h.gen_istft_n_fft, device=device).to(device)

    if rank == 0:
        print(generator)
        print("Size")
        print(get_n_params(generator))
        os.makedirs(a.checkpoint_path, exist_ok=True)
        print("checkpoints directory : ", a.checkpoint_path)

    # Default to "no checkpoint" so a missing directory starts a fresh run
    # instead of raising NameError below.
    cp_g = cp_do = None
    if os.path.isdir(a.checkpoint_path):
        cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
        cp_do = scan_checkpoint(a.checkpoint_path, 'do_')

    steps = 0
    if cp_g is None or cp_do is None:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        combd.load_state_dict(state_dict_do['combd'])
        sbd.load_state_dict(state_dict_do['sbd'])
        msd.load_state_dict(state_dict_do['msd'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if h.num_gpus > 1:
        generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
        combd = DistributedDataParallel(combd, device_ids=[rank]).to(device)
        sbd = DistributedDataParallel(sbd, device_ids=[rank]).to(device)
        msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)

    optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
    # A single optimizer drives all three discriminators.
    optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), sbd.parameters(), combd.parameters()),
                                h.learning_rate, betas=[h.adam_b1, h.adam_b2])

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)

    training_filelist, validation_filelist = get_dataset_filelist(a)

    # With several GPUs the DistributedSampler is responsible for shuffling,
    # so the dataset's own shuffle is only enabled for single-process runs.
    trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
                          h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
                          shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
                          fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)

    train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None

    train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
                              sampler=train_sampler,
                              batch_size=h.batch_size,
                              pin_memory=True,
                              drop_last=True)

    if rank == 0:
        validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
                              h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
                              fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
                              base_mels_path=a.input_mels_dir)
        validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
                                       sampler=None,
                                       batch_size=1,
                                       pin_memory=True,
                                       drop_last=True)

        sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))

    generator.train()
    combd.train()
    sbd.train()
    msd.train()

    for epoch in range(max(0, last_epoch), a.training_epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch+1))

        if h.num_gpus > 1:
            train_sampler.set_epoch(epoch)

        for batch in train_loader:
            if rank == 0:
                start_b = time.time()

            x, y, length, _, y_mel = batch
            # torch.autograd.Variable is a deprecated no-op wrapper; plain
            # tensors are used directly.
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            length = length.to(device, non_blocking=True)
            y_mel = y_mel.to(device, non_blocking=True)
            y = y.unsqueeze(1)

            # Only the full-band signal is given to CoMBD. Feeding PQMF
            # sub-band targets as well would require extra projection convs
            # in the generator, see
            # https://github.com/ncsoft/avocodo/blob/2999557bbd040a6f3eb6f7006a317d89537b78cd/avocodo/models/generator.py#L109
            ys = [y]

            spec, phase = generator(x, length)
            y_g_hat = stft.inverse(spec, phase)
            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
                                          h.hop_size, h.win_size,
                                          h.fmin, h.fmax_for_loss)

            # ---- Discriminator step: generator output is detached ----
            optim_d.zero_grad()

            # CoMBD: ys is a list of length 1 holding y (batch x 1 x segment_size);
            # y_g_hat has shape batch x 1 x segment_size.
            y_du_hat_r, y_du_hat_g, _, _ = combd(ys, [y_g_hat.detach()])
            loss_disc_u, losses_disc_u_r, losses_disc_u_g = discriminator_loss(y_du_hat_r, y_du_hat_g)

            # SBD
            y_dp_hat_r, y_dp_hat_g, _, _ = sbd(y, y_g_hat.detach())
            loss_disc_p, losses_disc_p_r, losses_disc_p_g = discriminator_loss(y_dp_hat_r, y_dp_hat_g)

            # MSD
            y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)

            loss_disc_all = loss_disc_s + loss_disc_p + loss_disc_u

            loss_disc_all.backward()
            optim_d.step()

            # ---- Generator step ----
            optim_g.zero_grad()

            # L1 mel-spectrogram loss, weighted by 45.
            loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45

            # Fix: CoMBD must see the non-detached generator output here.
            # It was previously fed the detached copy from the discriminator
            # step, so its adversarial and feature-matching terms gave the
            # generator no gradient.
            y_du_hat_r, y_du_hat_g, fmap_u_r, fmap_u_g = combd(ys, [y_g_hat])
            y_dp_hat_r, y_dp_hat_g, fmap_p_r, fmap_p_g = sbd(y, y_g_hat)
            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
            loss_fm_u = feature_loss(fmap_u_r, fmap_u_g)
            loss_fm_p = feature_loss(fmap_p_r, fmap_p_g)
            loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
            loss_gen_u, losses_gen_u = generator_loss(y_du_hat_g)
            loss_gen_p, losses_gen_p = generator_loss(y_dp_hat_g)
            loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
            loss_gen_all = loss_gen_s + loss_gen_p + loss_gen_u + loss_fm_s + loss_fm_u + loss_fm_p + loss_mel

            loss_gen_all.backward()
            optim_g.step()

            if rank == 0:
                log_stdout = steps % a.stdout_interval == 0
                log_summary = steps % a.summary_interval == 0

                # The unweighted mel error is needed by both the stdout and
                # the summary log; compute it whenever either fires so the
                # summary never reads a stale or undefined value.
                if log_stdout or log_summary:
                    with torch.no_grad():
                        mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()

                # STDOUT logging
                if log_stdout:
                    print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
                          format(steps, loss_gen_all, mel_error, time.time() - start_b))

                # Checkpointing
                if steps % a.checkpoint_interval == 0 and steps != 0:
                    _save_checkpoints(a, h, generator, combd, sbd, msd, optim_g, optim_d, steps, epoch)

                # Tensorboard summary logging
                if log_summary:
                    sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
                    sw.add_scalar("training/mel_spec_error", mel_error, steps)

                # Validation (also runs at step 0)
                if steps % a.validation_interval == 0:
                    _run_validation(generator, stft, validation_loader, sw, h, device, steps)

            steps += 1

        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))

    if rank == 0:
        # Flush any pending TensorBoard events before the process exits.
        sw.close()
def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``0``."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def main():
    """Parse CLI arguments, load the JSON config and launch training.

    When CUDA is available, ``h.num_gpus`` is set to the visible device count
    and ``h.batch_size`` is divided by it. With more than one GPU, one
    ``train`` process is spawned per device; otherwise ``train`` runs in this
    process as rank 0.
    """
    print('Initializing Training Process..')

    parser = argparse.ArgumentParser()

    parser.add_argument('--group_name', default=None)
    parser.add_argument('--input_wavs_dir', default='LJSpeech-1.1/wavs')
    parser.add_argument('--input_mels_dir', default='ft_dataset')
    parser.add_argument('--input_training_file', default='LJSpeech-1.1/training.txt')
    parser.add_argument('--input_validation_file', default='LJSpeech-1.1/validation.txt')
    parser.add_argument('--checkpoint_path', default='lightvoc_checkpoint')
    parser.add_argument('--config', default='')
    parser.add_argument('--training_epochs', default=3000, type=int)
    parser.add_argument('--stdout_interval', default=5, type=int)
    parser.add_argument('--checkpoint_interval', default=10000, type=int)
    parser.add_argument('--summary_interval', default=100, type=int)
    parser.add_argument('--validation_interval', default=2000, type=int)
    # Fix: ``type=bool`` turned every non-empty string (including "False")
    # into True; parse the text explicitly instead.
    parser.add_argument('--fine_tuning', default=False, type=_str2bool)

    a = parser.parse_args()

    with open(a.config) as f:
        json_config = json.load(f)
    h = AttrDict(json_config)
    build_env(a.config, 'config.json', a.checkpoint_path)

    torch.manual_seed(h.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        h.num_gpus = torch.cuda.device_count()
        h.batch_size = int(h.batch_size / h.num_gpus)
        print('Batch size per GPU :', h.batch_size)
    # Without CUDA, num_gpus and batch_size are used exactly as given in the
    # config file.

    if h.num_gpus > 1:
        mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
    else:
        train(0, a, h)
# Script entry point: start training only when this file is run directly, so
# that importing the module has no training side effect.
if __name__ == '__main__':
    main()