Skip to content

Commit

Permalink
deviate from paper, and start introducing spear tts text-to-semantic …
Browse files Browse the repository at this point in the history
…module as a way of aligning text to audio. the goal is to offer the old way (duration / pitch predictor) as well as text to semantic using spec decoding
  • Loading branch information
lucidrains committed Sep 23, 2023
1 parent 44a41e3 commit 0a61879
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 17 deletions.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'voicebox-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.34',
version = '0.0.35',
license='MIT',
description = 'Voicebox - Pytorch',
author = 'Phil Wang',
Expand All @@ -21,6 +21,7 @@
'beartype',
'einops>=0.6.1',
'lightning>=2.0.7',
'spear-tts-pytorch>=0.3.4',
'torch>=2.0',
'torchdiffeq',
'torchode',
Expand Down
107 changes: 91 additions & 16 deletions voicebox_pytorch/voicebox_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torchdiffeq import odeint

from beartype import beartype
from beartype.typing import Tuple, Optional
from beartype.typing import Tuple, Optional, List

from einops.layers.torch import Rearrange
from einops import rearrange, repeat, reduce, pack, unpack
Expand All @@ -23,13 +23,15 @@
from naturalspeech2_pytorch.aligner import Aligner, ForwardSumLoss, maximum_path

from audiolm_pytorch import EncodecWrapper
from spear_tts_pytorch import TextToSemantic

import torchaudio.transforms as T
from torchaudio.functional import DB_to_amplitude

from vocos import Vocos

LOGGER = logging.getLogger(__file__)

# helper functions

def exists(val):
Expand Down Expand Up @@ -599,6 +601,7 @@ def forward(
if should_align:
alignment_hard, _, alignment_logprob, _ = self.forward_aligner(phoneme_emb, phoneme_mask, mel, mel_mask)
target = alignment_hard

# combine audio, phoneme, conditioning

embed = torch.cat((x, phoneme_emb, cond), dim = -1)
Expand Down Expand Up @@ -640,10 +643,10 @@ class VoiceBox(Module):
def __init__(
self,
*,
num_phoneme_tokens,
num_cond_tokens,
audio_enc_dec: Optional[AudioEncoderDecoder] = None,
dim_in = None,
dim_phoneme_emb = 1024,
dim_cond_emb = 1024,
dim = 1024,
depth = 24,
dim_head = 64,
Expand Down Expand Up @@ -675,13 +678,13 @@ def __init__(
nn.SiLU()
)

self.null_phoneme_id = num_phoneme_tokens # use last phoneme token as null token for CFG
self.to_phoneme_emb = nn.Embedding(num_phoneme_tokens + 1, dim_phoneme_emb)
self.null_cond_id = num_cond_tokens # use last cond token as null token for CFG
self.to_cond_emb = nn.Embedding(num_cond_tokens + 1, dim_cond_emb)

self.p_drop_prob = p_drop_prob
self.frac_lengths_mask = frac_lengths_mask

self.to_embed = nn.Linear(dim_in * 2 + dim_phoneme_emb, dim)
self.to_embed = nn.Linear(dim_in * 2 + dim_cond_emb, dim)

self.null_cond = nn.Parameter(torch.zeros(dim_in))

Expand Down Expand Up @@ -730,9 +733,9 @@ def forward(
self,
x,
*,
phoneme_ids,
cond,
times,
cond_token_ids,
cond_drop_prob = 0.1,
target = None,
mask = None,
Expand Down Expand Up @@ -774,14 +777,24 @@ def forward(
cond
)

phoneme_ids = torch.where(
cond_ids = torch.where(
rearrange(cond_drop_mask, '... -> ... 1'),
self.null_phoneme_id,
phoneme_ids
self.null_cond_id,
cond_token_ids
)

phoneme_emb = self.to_phoneme_emb(phoneme_ids)
embed = torch.cat((x, phoneme_emb, cond), dim = -1)
cond_emb = self.to_cond_emb(cond_token_ids)

# (todo) align conditioning embed if needed
# needed for spear-tts semantic <-> audio

min_length = min([t.shape[-2] for t in (x, cond_emb, cond)])
x, cond_emb, cond = tuple(t[..., :min_length, :] for t in (x, cond_emb, cond))

# concat source signal, semantic / phoneme conditioning embed, and conditioning
# and project

embed = torch.cat((x, cond_emb, cond), dim = -1)
x = self.to_embed(embed)

x = self.conv_embed(x) + x
Expand All @@ -799,6 +812,9 @@ def forward(
if not exists(target):
return x

target = target[..., :min_length, :]
mask = mask[..., :min_length]

if not exists(mask):
return F.mse_loss(x, target)

Expand All @@ -825,6 +841,7 @@ class ConditionalFlowMatcherWrapper(Module):
def __init__(
self,
voicebox: VoiceBox,
text_to_semantic: Optional[TextToSemantic] = None,
sigma = 0.,
ode_atol = 1e-5,
ode_rtol = 1e-5,
Expand All @@ -839,6 +856,8 @@ def __init__(

self.voicebox = voicebox

self.text_to_semantic = text_to_semantic

self.cond_drop_prob = cond_drop_prob

self.use_torchode = use_torchode
Expand All @@ -859,8 +878,11 @@ def device(self):
def sample(
self,
*,
phoneme_ids,
cond,
texts: Optional[List[str]] = None,
text_token_ids: Optional[Tensor] = None,
semantic_token_ids = None,
phoneme_ids = None,
mask = None,
steps = 3,
cond_scale = 1.,
Expand All @@ -879,6 +901,34 @@ def sample(
self.voicebox.audio_enc_dec.eval()
cond = self.voicebox.audio_enc_dec.encode(cond)

# setup text conditioning, either coming from duration model (as phoneme ids)
# or coming from the text-to-semantic module from the spear-tts paper, as (semantic ids)
# todo, DRY the conditioning logic, if sampling and training is not too different once everything is done

assert sum([*map(exists, (texts, text_token_ids, semantic_token_ids, phoneme_ids))]) <= 1

using_text_to_semantic = exists(texts) or exists(text_token_ids)

if using_text_to_semantic:
assert exists(self.text_to_semantic), 'TextToSemantic must be passed into the ConditionalFlowMatcherWrapper as `text_to_semantic` in order to train directly on text'

if using_text_to_semantic:
semantic_token_ids = self.text_to_semantic.generate(
source = default(text_token_ids, texts),
source_type = 'text',
target_type = 'speech',
max_length = 10
)

cond_token_ids = semantic_token_ids if using_text_to_semantic else phoneme_ids

# todo (properly align for text to semantic)

min_length = min([t.shape[-2] for t in (cond_token_ids, cond)])
cond_token_ids, cond = tuple(t[..., :min_length, :] for t in (cond_token_ids, cond))

# neural ode

self.voicebox.eval()

def fn(t, x, *, packed_shape = None):
Expand All @@ -888,7 +938,7 @@ def fn(t, x, *, packed_shape = None):
out = self.voicebox.forward_with_cond_scale(
x,
times = t,
phoneme_ids = phoneme_ids,
cond_token_ids = cond_token_ids,
cond = cond,
cond_scale = cond_scale
)
Expand Down Expand Up @@ -942,8 +992,11 @@ def forward(
self,
x1,
*,
phoneme_ids,
cond,
texts: Optional[List[str]] = None,
text_token_ids: Optional[Tensor] = None,
semantic_token_ids = None,
phoneme_ids = None,
mask = None
):
"""
Expand All @@ -968,6 +1021,28 @@ def forward(
if cond_is_raw_audio:
cond = self.voicebox.audio_enc_dec.encode(cond)

# setup text conditioning, either coming from duration model (as phoneme ids)
# or coming from the text-to-semantic module from the spear-tts paper, as (semantic ids)

assert sum([*map(exists, (texts, text_token_ids, semantic_token_ids, phoneme_ids))]) <= 1

using_text_to_semantic = exists(texts) or exists(text_token_ids)

if using_text_to_semantic:
assert exists(self.text_to_semantic), 'TextToSemantic must be passed into the ConditionalFlowMatcherWrapper as `text_to_semantic` in order to train directly on text'

if using_text_to_semantic:
semantic_token_ids = self.text_to_semantic.generate(
source = default(text_token_ids, texts),
source_type = 'text',
target_type = 'speech',
max_length = 10
)

cond_token_ids = semantic_token_ids if using_text_to_semantic else phoneme_ids

# main conditional flow logic is below, in a mere 5 loc

# x0 is gaussian noise

x0 = torch.randn_like(x1)
Expand All @@ -989,11 +1064,11 @@ def forward(

loss = self.voicebox(
w,
phoneme_ids = phoneme_ids,
cond = cond,
mask = mask,
times = times,
target = flow,
cond_token_ids = cond_token_ids,
cond_drop_prob = self.cond_drop_prob
)

Expand Down

0 comments on commit 0a61879

Please sign in to comment.