added VITS recipe
JinZr committed Oct 21, 2024
1 parent e0136d9 commit 2a5aa7c
Showing 19 changed files with 1,774 additions and 0 deletions.
3 changes: 3 additions & 0 deletions egs/libritts/TTS/prepare.sh
@@ -53,6 +53,9 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Downloading x-vector"

git clone https://huggingface.co/datasets/zrjin/xvector_nnet_1a_libritts_clean_460 $dl_dir/xvector_nnet_1a_libritts_clean_460

mkdir -p exp/xvector_nnet_1a/
cp -r $dl_dir/xvector_nnet_1a_libritts_clean_460/* exp/xvector_nnet_1a/
fi

fi
1 change: 1 addition & 0 deletions egs/libritts/TTS/vits/duration_predictor.py
1 change: 1 addition & 0 deletions egs/libritts/TTS/vits/flow.py
1 change: 1 addition & 0 deletions egs/libritts/TTS/vits/generator.py
1 change: 1 addition & 0 deletions egs/libritts/TTS/vits/hifigan.py
273 changes: 273 additions & 0 deletions egs/libritts/TTS/vits/infer.py
@@ -0,0 +1,273 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script performs model inference on the test and validation sets.
Usage:
./vits/infer.py \
--epoch 1000 \
--exp-dir ./vits/exp \
--max-duration 500
"""


import argparse
import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List

import k2
import torch
import torch.nn as nn
import torchaudio
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import LibrittsTtsDataModule

from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger


def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument(
"--epoch",
type=int,
default=1000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)

parser.add_argument(
"--exp-dir",
type=str,
default="vits/exp",
help="The experiment dir",
)

parser.add_argument(
"--tokens",
type=str,
default="data/tokens.txt",
help="""Path to vocabulary.""",
)

return parser


def infer_dataset(
dl: torch.utils.data.DataLoader,
subset: str,
params: AttributeDict,
model: nn.Module,
tokenizer: Tokenizer,
speaker_map: Dict[str, int],
) -> None:
"""Decode dataset.
The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
tokenizer:
Used to convert text to phonemes.
"""

# Background worker that saves audio to disk.
def _save_worker(
subset: str,
batch_size: int,
cut_ids: List[str],
audio: torch.Tensor,
audio_pred: torch.Tensor,
audio_lens: List[int],
audio_lens_pred: List[int],
):
for i in range(batch_size):
torchaudio.save(
str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"),
audio[i : i + 1, : audio_lens[i]],
sample_rate=params.sampling_rate,
)
torchaudio.save(
str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"),
audio_pred[i : i + 1, : audio_lens_pred[i]],
sample_rate=params.sampling_rate,
)

device = next(model.parameters()).device
num_cuts = 0
log_interval = 5

try:
num_batches = len(dl)
except TypeError:
num_batches = "?"

futures = []
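# A single background worker serializes the wav writes so that disk I/O
# does not block inference on the device.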
with ThreadPoolExecutor(max_workers=1) as executor:
for batch_idx, batch in enumerate(dl):
batch_size = len(batch["tokens"])

tokens = batch["tokens"]
tokens = tokenizer.tokens_to_token_ids(
tokens, intersperse_blank=True, add_sos=True, add_eos=True
)
tokens = k2.RaggedTensor(tokens)
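# derive the number of tokens in each utterance from the ragged
# tensor's row splits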
row_splits = tokens.shape.row_splits(1)
tokens_lens = row_splits[1:] - row_splits[:-1]
tokens = tokens.to(device)
tokens_lens = tokens_lens.to(device)
# tensor of shape (B, T)
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
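# map each utterance's speaker name to the integer speaker id
# expected by the model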
speakers = (
torch.Tensor([speaker_map[sid] for sid in batch["speakers"]])
.int()
.to(device)
)

audio = batch["audio"]
audio_lens = batch["audio_lens"].tolist()
cut_ids = [cut.id for cut in batch["cut"]]

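# generate waveforms for the whole batch; the returned durations are
# predicted frame counts, used below to recover waveform lengths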
audio_pred, _, durations = model.inference_batch(
text=tokens,
text_lengths=tokens_lens,
sids=speakers,
)
audio_pred = audio_pred.detach().cpu()
# convert the predicted durations from frames to waveform lengths in samples
audio_lens_pred = (
(durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
)

futures.append(
executor.submit(
_save_worker,
subset,
batch_size,
cut_ids,
audio,
audio_pred,
audio_lens,
audio_lens_pred,
)
)

num_cuts += batch_size

if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"

logging.info(
f"batch {batch_str}, {num_cuts} cuts processed so far"
)
# wait for all background save tasks to finish
for f in futures:
f.result()


@torch.no_grad()
def main():
parser = get_parser()
LibrittsTtsDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)

params = get_params()
params.update(vars(args))

params.suffix = f"epoch-{params.epoch}"

params.res_dir = params.exp_dir / "infer" / params.suffix
params.save_wav_dir = params.res_dir / "wav"
params.save_wav_dir.mkdir(parents=True, exist_ok=True)

setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
logging.info("Infer started")

device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)

tokenizer = Tokenizer(params.tokens)
params.blank_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size

# we need cut ids to name the saved wav files.
args.return_cuts = True
libritts = LibrittsTtsDataModule(args)
speaker_map = libritts.speakers()
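# the number of distinct speakers sizes the model's speaker embedding table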
params.num_spks = len(speaker_map)

logging.info(f"Device: {device}")
logging.info(params)

logging.info("About to create model")
model = get_model(params)

load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

model.to(device)
model.eval()

num_param_g = sum([p.numel() for p in model.generator.parameters()])
logging.info(f"Number of parameters in generator: {num_param_g}")
num_param_d = sum([p.numel() for p in model.discriminator.parameters()])
logging.info(f"Number of parameters in discriminator: {num_param_d}")
logging.info(f"Total number of parameters: {num_param_g + num_param_d}")

test_cuts = libritts.test_cuts()
test_dl = libritts.test_dataloaders(test_cuts)

valid_cuts = libritts.valid_cuts()
valid_dl = libritts.valid_dataloaders(valid_cuts)

infer_sets = {"test": test_dl, "valid": valid_dl}

for subset, dl in infer_sets.items():
save_wav_dir = params.res_dir / "wav" / subset
save_wav_dir.mkdir(parents=True, exist_ok=True)

logging.info(f"Processing {subset} set, saving to {save_wav_dir}")

infer_dataset(
dl=dl,
subset=subset,
params=params,
model=model,
tokenizer=tokenizer,
speaker_map=speaker_map,
)

logging.info(f"Wav files are saved to {params.save_wav_dir}")
logging.info("Done!")


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions egs/libritts/TTS/vits/loss.py
1 change: 1 addition & 0 deletions egs/libritts/TTS/vits/monotonic_align
1 change: 1 addition & 0 deletions egs/libritts/TTS/vits/posterior_encoder.py
1 change: 1 addition & 0 deletions egs/libritts/TTS/vits/residual_coupling.py
