Skip to content
This repository was archived by the owner on Oct 6, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/python/piper_train/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from .vits.lightning import VitsModel

Expand All @@ -24,6 +24,11 @@ def main():
type=int,
help="Save checkpoint every N epochs (default: 1)",
)
parser.add_argument(
"--patience",
type=int,
help="Number of validation cycles to allow to pass without improvement before stopping training"
)
parser.add_argument(
"--quality",
default="medium",
Expand Down Expand Up @@ -57,12 +62,15 @@ def main():
num_speakers = int(config["num_speakers"])
sample_rate = int(config["audio"]["sample_rate"])

trainer = Trainer.from_argparse_args(args)
callbacks = []
if args.checkpoint_epochs is not None:
trainer.callbacks = [ModelCheckpoint(every_n_epochs=args.checkpoint_epochs)]
callbacks.append(ModelCheckpoint(every_n_epochs=args.checkpoint_epochs, monitor="val_loss", save_top_k=1, mode="min"))
_LOGGER.debug(
"Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs
)
if args.patience is not None:
callbacks.append(EarlyStopping(monitor="val_loss", min_delta=0.00, patience=args.patience, verbose=True, mode="min"))
trainer = Trainer.from_argparse_args(args, callbacks=callbacks)

dict_args = vars(args)
if args.quality == "x-low":
Expand Down
45 changes: 24 additions & 21 deletions src/python/piper_train/vits/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,29 +282,32 @@ def training_step_d(self, batch: Batch):
def validation_step(self, batch: Batch, batch_idx: int):
    """Compute and log the validation loss for a single batch.

    The loss is the sum of the values returned by ``training_step_g`` and
    ``training_step_d`` for ``batch``. It is logged under the key
    ``"val_loss"`` and returned.

    Audio examples are deliberately NOT generated here. This hook runs once
    per validation batch (sanity-check batches included), so synthesizing
    every utterance of the test dataset in it repeats the same work for
    each batch. ``on_validation_end`` already synthesizes and logs those
    examples once per validation run.

    Args:
        batch: The validation batch to score.
        batch_idx: Index of the batch within the validation run (unused).

    Returns:
        The combined loss for ``batch``.
    """
    val_loss = self.training_step_g(batch) + self.training_step_d(batch)
    self.log("val_loss", val_loss)
    return val_loss

def on_validation_end(self) -> None:
    """Synthesize and log one audio clip per test utterance after validation.

    Nothing is synthesized while the trainer is running its sanity check.
    Each utterance in ``self._test_dataset`` is run through the model,
    peak-normalized, and written to the experiment logger under the
    utterance text (or its index when the text is empty). Clips are tagged
    with the current ``global_step`` so that successive validation runs are
    recorded as separate points instead of all landing on the same step.
    """
    if not self.trainer.sanity_checking:
        # Inference scales are the same for every utterance.
        scales = [0.667, 1.0, 0.8]

        for utt_idx, test_utt in enumerate(self._test_dataset):
            phoneme_ids = test_utt.phoneme_ids
            text = phoneme_ids.unsqueeze(0).to(self.device)
            text_lengths = torch.LongTensor([len(phoneme_ids)]).to(self.device)
            sid = (
                test_utt.speaker_id.to(self.device)
                if test_utt.speaker_id is not None
                else None
            )
            test_audio = self(text, text_lengths, scales, sid=sid).detach()

            # Scale to make louder in [-1, 1]. The peak is taken over the
            # absolute signal so a dominant negative excursion is handled;
            # the 0.01 floor avoids blowing up near-silent output.
            peak = max(0.01, abs(test_audio).max())
            test_audio = test_audio * (1.0 / peak)

            tag = test_utt.text or str(utt_idx)
            self.logger.experiment.add_audio(
                tag,
                test_audio,
                global_step=self.global_step,
                sample_rate=self.hparams.sample_rate,
            )

    return super().on_validation_end()

def configure_optimizers(self):
optimizers = [
torch.optim.AdamW(
Expand Down
67 changes: 35 additions & 32 deletions src/python/piper_train/vits/mel_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,24 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
mode="reflect",
)
y = y.squeeze(1)

spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
with torch.autocast(device_type=y.device.type, dtype=torch.float32):
y = y.to(y.device.type, torch.float32)
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)
)

spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

return spec

Expand Down Expand Up @@ -116,24 +117,26 @@ def mel_spectrogram_torch(
mode="reflect",
)
y = y.squeeze(1)
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
with torch.autocast(device_type=y.device.type, dtype=torch.float32):
y = y.to(y.device.type, torch.float32)
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)
)
# print(y.dtype, spec.dtype)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec)

return spec