Commit 56d3b92

First working version.

csukuangfj committed Oct 16, 2024
1 parent ccd2dcc commit 56d3b92
Showing 5 changed files with 854 additions and 42 deletions.
178 changes: 178 additions & 0 deletions egs/ljspeech/TTS/matcha/inference.py
@@ -0,0 +1,178 @@
#!/usr/bin/env python3

import argparse
import datetime as dt
import logging
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.models import Generator as HiFiGAN
from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.utils import intersperse
from tqdm.auto import tqdm
from train import get_model, get_params

from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=140,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=Path,
        default="matcha/exp",
        help="""The experiment dir.
        It specifies the directory where all training-related
        files, e.g., checkpoints, logs, etc., are saved.
        """,
    )

    return parser
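
# Example invocation (the values shown are just the argparse defaults above;
# adjust them to your own experiment):
#
#   python3 ./matcha/inference.py --epoch 140 --exp-dir matcha/exp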


def load_vocoder(checkpoint_path):
    h = AttributeDict(v1)
    hifigan = HiFiGAN(h).to("cpu")
    hifigan.load_state_dict(
        torch.load(checkpoint_path, map_location="cpu")["generator"]
    )
    _ = hifigan.eval()
    hifigan.remove_weight_norm()
    return hifigan


def to_waveform(mel, vocoder, denoiser):
    audio = vocoder(mel).clamp(-1, 1)
    # Denoise, then move to CPU and drop the batch dimension.
    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
    return audio


def save_to_folder(filename: str, output: dict, folder: str):
    folder = Path(folder)
    folder.mkdir(exist_ok=True, parents=True)
    # np.save appends the .npy suffix automatically.
    np.save(folder / f"{filename}", output["mel"].cpu().numpy())
    sf.write(folder / f"{filename}.wav", output["waveform"], 22050, "PCM_24")


def process_text(text: str):
    x = torch.tensor(
        intersperse(text_to_sequence(text, ["english_cleaners2"])[0], 0),
        dtype=torch.long,
        device="cpu",
    )[None]
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
    x_phones = sequence_to_text(x.squeeze(0).tolist())
    return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones}


def synthesise(model, n_timesteps, text, length_scale, temperature, spks=None):
    text_processed = process_text(text)
    start_t = dt.datetime.now()
    output = model.synthesise(
        text_processed["x"],
        text_processed["x_lengths"],
        n_timesteps=n_timesteps,
        temperature=temperature,
        spks=spks,
        length_scale=length_scale,
    )
    print("output keys:", list(output.keys()), "mel shape:", output["mel"].shape)
    # Merge everything into one dict
    output.update({"start_t": start_t, **text_processed})
    return output
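
# Note: model.synthesise already reports an "rtf" entry covering the acoustic
# model alone; main() below additionally computes rtf_w, which includes
# HiFi-GAN vocoding: rtf_w = wall_clock_seconds * 22050 / num_samples.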


@torch.inference_mode()
def main():
    parser = get_parser()
    args = parser.parse_args()
    params = get_params()

    params.update(vars(args))
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    model.eval()

    vocoder = load_vocoder("/star-fj/fangjun/open-source/Matcha-TTS/generator_v1")
    denoiser = Denoiser(vocoder, mode="zeros")

    texts = [
        "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
        "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar.",
    ]

    # Number of ODE solver steps
    n_timesteps = 2

    # Changes to the speaking rate
    length_scale = 1.0

    # Sampling temperature
    temperature = 0.667

    outputs, rtfs = [], []
    rtfs_w = []
    for i, text in enumerate(tqdm(texts)):
        output = synthesise(
            model=model,
            n_timesteps=n_timesteps,
            text=text,
            length_scale=length_scale,
            temperature=temperature,
        )  # , torch.tensor([15], device=device, dtype=torch.long).unsqueeze(0))
        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)

        # Compute Real Time Factor (RTF) with HiFi-GAN
        t = (dt.datetime.now() - output["start_t"]).total_seconds()
        rtf_w = t * 22050 / (output["waveform"].shape[-1])

        # Pretty print
        print(f"{'*' * 53}")
        print(f"Input text - {i}")
        print(f"{'-' * 53}")
        print(output["x_orig"])
        print(f"{'*' * 53}")
        print(f"Phonetised text - {i}")
        print(f"{'-' * 53}")
        print(output["x_phones"])
        print(f"{'*' * 53}")
        print(f"RTF:\t\t{output['rtf']:.6f}")
        print(f"RTF Waveform:\t{rtf_w:.6f}")
        rtfs.append(output["rtf"])
        rtfs_w.append(rtf_w)

        # Save the generated waveform
        save_to_folder(str(i), output, folder="./my-output")

    print(f"Number of ODE steps: {n_timesteps}")
    print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")
    print(
        f"Mean RTF Waveform (incl. vocoder):\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}"
    )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
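
Once the script has run, each sentence index yields a mel-spectrogram (.npy) and a waveform (.wav) under ./my-output. A quick sketch for loading the artifacts back, assuming index 0 exists; note that np.save above appends the .npy suffix itself:

import numpy as np
import soundfile as sf

mel = np.load("./my-output/0.npy")  # mel-spectrogram saved by save_to_folder
audio, sr = sf.read("./my-output/0.wav")  # sr should be 22050
print("mel:", mel.shape, "audio:", audio.shape, "sr:", sr)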
74 changes: 64 additions & 10 deletions egs/ljspeech/TTS/matcha/models/matcha_tts.py
@@ -5,6 +5,7 @@
 import torch

 import matcha.utils.monotonic_align as monotonic_align
+
 # from matcha import utils
 # from matcha.models.baselightningmodule import BaseLightningClass
 from matcha.models.components.flow_matching import CFM
@@ -30,7 +31,7 @@ def __init__(
         encoder,
         decoder,
         cfm,
-        # data_statistics,
+        data_statistics,
         out_size,
         optimizer=None,
         scheduler=None,
@@ -71,9 +72,13 @@
         )

         # self.update_data_statistics(data_statistics)
+        self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
+        self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))

     @torch.inference_mode()
-    def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0):
+    def synthesise(
+        self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0
+    ):
         """
         Generates mel-spectrogram from text. Returns:
             1. encoder outputs
@@ -149,7 +154,17 @@ def synthesise(
             "rtf": rtf,
         }

-    def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
+    def forward(
+        self,
+        x,
+        x_lengths,
+        y,
+        y_lengths,
+        spks=None,
+        out_size=None,
+        cond=None,
+        durations=None,
+    ):
         """
         Computes 3 losses:
             1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
@@ -187,7 +202,9 @@ def forward(
         # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
         with torch.no_grad():
             const = -0.5 * math.log(2 * math.pi) * self.n_feats
-            factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
+            factor = -0.5 * torch.ones(
+                mu_x.shape, dtype=mu_x.dtype, device=mu_x.device
+            )
             y_square = torch.matmul(factor.transpose(1, 2), y**2)
             y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
             mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
@@ -206,12 +223,25 @@ def forward(
         # - Do not need this hack for Matcha-TTS, but it works with it as well
         if not isinstance(out_size, type(None)):
             max_offset = (y_lengths - out_size).clamp(0)
-            offset_ranges = list(zip([0] * max_offset.shape[0], max_offset.cpu().numpy()))
+            offset_ranges = list(
+                zip([0] * max_offset.shape[0], max_offset.cpu().numpy())
+            )
             out_offset = torch.LongTensor(
-                [torch.tensor(random.choice(range(start, end)) if end > start else 0) for start, end in offset_ranges]
+                [
+                    torch.tensor(random.choice(range(start, end)) if end > start else 0)
+                    for start, end in offset_ranges
+                ]
             ).to(y_lengths)
-            attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device)
-            y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device)
+            attn_cut = torch.zeros(
+                attn.shape[0],
+                attn.shape[1],
+                out_size,
+                dtype=attn.dtype,
+                device=attn.device,
+            )
+            y_cut = torch.zeros(
+                y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device
+            )

             y_cut_lengths = []
             for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
@@ -233,12 +263,36 @@ def forward(
         mu_y = mu_y.transpose(1, 2)

         # Compute loss of the decoder
-        diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)
+        diff_loss, _ = self.decoder.compute_loss(
+            x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond
+        )

         if self.prior_loss:
-            prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
+            prior_loss = torch.sum(
+                0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask
+            )
             prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
         else:
             prior_loss = 0

         return dur_loss, prior_loss, diff_loss, attn
+
+    def get_losses(self, batch):
+        x, x_lengths = batch["x"], batch["x_lengths"]
+        y, y_lengths = batch["y"], batch["y_lengths"]
+        spks = batch["spks"]
+
+        dur_loss, prior_loss, diff_loss, *_ = self(
+            x=x,
+            x_lengths=x_lengths,
+            y=y,
+            y_lengths=y_lengths,
+            spks=spks,
+            out_size=self.out_size,
+            durations=batch["durations"],
+        )
+        return {
+            "dur_loss": dur_loss,
+            "prior_loss": prior_loss,
+            "diff_loss": diff_loss,
+        }
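
The new get_losses helper returns the three losses separately rather than a single scalar. A minimal sketch of how a training loop might combine them; the unit weighting below is an assumption, not something this commit specifies:

def compute_total_loss(model, batch):
    # Hypothetical helper: sums the duration, prior, and flow-matching losses
    # with unit weights. The recipe's actual weighting is not shown here.
    losses = model.get_losses(batch)
    return losses["dur_loss"] + losses["prior_loss"] + losses["diff_loss"]

# Typical training step (sketch):
#   loss = compute_total_loss(model, batch)
#   optimizer.zero_grad()
#   loss.backward()
#   optimizer.step()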