Commit 56d3b92 (1 parent: ccd2dcc). Showing 5 changed files with 854 additions and 42 deletions.
@@ -0,0 +1,178 @@
#!/usr/bin/env python3

import argparse
import datetime as dt
import logging
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.models import Generator as HiFiGAN
from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.utils import intersperse
from tqdm.auto import tqdm
from train import get_model, get_params

from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict

def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=140,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=Path,
        default="matcha/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, logs, etc., are saved.
        """,
    )

    return parser


def load_vocoder(checkpoint_path):
    # Build the HiFi-GAN v1 vocoder and load its generator weights.
    h = AttributeDict(v1)
    hifigan = HiFiGAN(h).to("cpu")
    hifigan.load_state_dict(
        torch.load(checkpoint_path, map_location="cpu")["generator"]
    )
    _ = hifigan.eval()
    hifigan.remove_weight_norm()
    return hifigan


def to_waveform(mel, vocoder, denoiser):
    # Vocode the mel-spectrogram and lightly denoise the resulting audio.
    audio = vocoder(mel).clamp(-1, 1)
    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
    return audio


def save_to_folder(filename: str, output: dict, folder: str):
    # Save the mel-spectrogram (.npy) and the waveform (.wav) to the folder.
    folder = Path(folder)
    folder.mkdir(exist_ok=True, parents=True)
    np.save(folder / f"{filename}", output["mel"].cpu().numpy())
    sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24")


def process_text(text: str):
    # Convert raw text to a phoneme ID tensor (with blanks interspersed) and its length.
    x = torch.tensor(
        intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0),
        dtype=torch.long,
        device="cpu",
    )[None]
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
    x_phones = sequence_to_text(x.squeeze(0).tolist())
    return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones}


def synthesise(model, n_timesteps, text, length_scale, temperature, spks=None):
    text_processed = process_text(text)
    start_t = dt.datetime.now()
    output = model.synthesise(
        text_processed["x"],
        text_processed["x_lengths"],
        n_timesteps=n_timesteps,
        temperature=temperature,
        spks=spks,
        length_scale=length_scale,
    )
    print("output keys:", list(output.keys()), "mel shape:", output["mel"].shape)
    # Merge the timing and text information into the output dict.
    output.update({"start_t": start_t, **text_processed})
    return output


@torch.inference_mode()
def main():
    parser = get_parser()
    args = parser.parse_args()
    params = get_params()

    params.update(vars(args))
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    model.eval()

    # Hardcoded path to a pre-trained HiFi-GAN v1 generator checkpoint.
    vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
    denoiser = Denoiser(vocoder, mode="zeros")

    texts = [
        "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
        "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
    ]

    # Number of ODE solver steps
    n_timesteps = 2

    # Changes to the speaking rate
    length_scale = 1.0

    # Sampling temperature
    temperature = 0.667

    rtfs = []
    rtfs_w = []
    for i, text in enumerate(tqdm(texts)):
        output = synthesise(
            model=model,
            n_timesteps=n_timesteps,
            text=text,
            length_scale=length_scale,
            temperature=temperature,
        )  # a spks tensor, e.g. torch.tensor([15], dtype=torch.long).unsqueeze(0), could be passed for multi-speaker models
        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)

        # Compute the real time factor (RTF) including the HiFi-GAN vocoder
        t = (dt.datetime.now() - output["start_t"]).total_seconds()
        rtf_w = t * 22050 / (output["waveform"].shape[-1])

        # Pretty print
        print(f"{'*' * 53}")
        print(f"Input text - {i}")
        print(f"{'-' * 53}")
        print(output["x_orig"])
        print(f"{'*' * 53}")
        print(f"Phonetised text - {i}")
        print(f"{'-' * 53}")
        print(output["x_phones"])
        print(f"{'*' * 53}")
        print(f"RTF:\t\t{output['rtf']:.6f}")
        print(f"RTF Waveform:\t{rtf_w:.6f}")
        rtfs.append(output["rtf"])
        rtfs_w.append(rtf_w)

        # Save the generated mel-spectrogram and waveform
        save_to_folder(str(i), output, folder="./my-output")

print(f"Number of ODE steps: {n_timesteps}") | ||
print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}") | ||
print( | ||
f"Mean RTF Waveform (incl. vocoder):\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}" | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" | ||
|
||
logging.basicConfig(format=formatter, level=logging.INFO) | ||
main() |
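
For quick experimentation outside of main(), the same helpers can be driven from a short standalone snippet. The sketch below is only illustrative: the module name "infer" and the local checkpoint paths are assumptions, not part of this commit.

# Illustrative sketch only: module name "infer" and checkpoint paths are assumptions.
import torch

from infer import load_vocoder, synthesise, to_waveform
from matcha.hifigan.denoiser import Denoiser
from train import get_model, get_params

from icefall.checkpoint import load_checkpoint

params = get_params()
params.update({"epoch": 140, "exp_dir": "matcha/exp"})  # assumed checkpoint location

model = get_model(params)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.eval()

vocoder = load_vocoder("./generator_v1")  # assumed local HiFi-GAN v1 checkpoint
denoiser = Denoiser(vocoder, mode="zeros")

with torch.inference_mode():
    out = synthesise(
        model=model,
        n_timesteps=2,
        text="This is a test sentence.",
        length_scale=1.0,
        temperature=0.667,
    )
    out["waveform"] = to_waveform(out["mel"], vocoder, denoiser)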