diff --git a/setup.py b/setup.py
index ef6b2b2..5fabfc4 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'voicebox-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.34',
+  version = '0.0.35',
   license='MIT',
   description = 'Voicebox - Pytorch',
   author = 'Phil Wang',
@@ -21,6 +21,7 @@
     'beartype',
     'einops>=0.6.1',
     'lightning>=2.0.7',
+    'spear-tts-pytorch>=0.3.4',
     'torch>=2.0',
     'torchdiffeq',
     'torchode',
diff --git a/voicebox_pytorch/voicebox_pytorch.py b/voicebox_pytorch/voicebox_pytorch.py
index 1cd1ec1..1e80926 100644
--- a/voicebox_pytorch/voicebox_pytorch.py
+++ b/voicebox_pytorch/voicebox_pytorch.py
@@ -13,7 +13,7 @@
 from torchdiffeq import odeint
 
 from beartype import beartype
-from beartype.typing import Tuple, Optional
+from beartype.typing import Tuple, Optional, List
 
 from einops.layers.torch import Rearrange
 from einops import rearrange, repeat, reduce, pack, unpack
@@ -23,6 +23,7 @@
 from naturalspeech2_pytorch.aligner import Aligner, ForwardSumLoss, maximum_path
 
 from audiolm_pytorch import EncodecWrapper
+from spear_tts_pytorch import TextToSemantic
 
 import torchaudio.transforms as T
 from torchaudio.functional import DB_to_amplitude
@@ -30,6 +31,7 @@
 from vocos import Vocos
 
 LOGGER = logging.getLogger(__file__)
+
 # helper functions
 
 def exists(val):
@@ -599,6 +601,7 @@ def forward(
         if should_align:
             alignment_hard, _, alignment_logprob, _ = self.forward_aligner(phoneme_emb, phoneme_mask, mel, mel_mask)
             target = alignment_hard
+
         # combine audio, phoneme, conditioning
 
         embed = torch.cat((x, phoneme_emb, cond), dim = -1)
@@ -640,10 +643,10 @@ class VoiceBox(Module):
     def __init__(
         self,
         *,
-        num_phoneme_tokens,
+        num_cond_tokens,
        audio_enc_dec: Optional[AudioEncoderDecoder] = None,
         dim_in = None,
-        dim_phoneme_emb = 1024,
+        dim_cond_emb = 1024,
         dim = 1024,
         depth = 24,
         dim_head = 64,
@@ -675,13 +678,13 @@ def __init__(
             nn.SiLU()
         )
 
-        self.null_phoneme_id = num_phoneme_tokens # use last phoneme token as null token for CFG
-        self.to_phoneme_emb = nn.Embedding(num_phoneme_tokens + 1, dim_phoneme_emb)
+        self.null_cond_id = num_cond_tokens # use last token as the null token for CFG
+        self.to_cond_emb = nn.Embedding(num_cond_tokens + 1, dim_cond_emb)
 
         self.p_drop_prob = p_drop_prob
         self.frac_lengths_mask = frac_lengths_mask
 
-        self.to_embed = nn.Linear(dim_in * 2 + dim_phoneme_emb, dim)
+        self.to_embed = nn.Linear(dim_in * 2 + dim_cond_emb, dim)
 
         self.null_cond = nn.Parameter(torch.zeros(dim_in))
@@ -730,9 +733,9 @@ def forward(
         self,
         x,
         *,
-        phoneme_ids,
         cond,
         times,
+        cond_token_ids,
         cond_drop_prob = 0.1,
         target = None,
         mask = None,
@@ -774,14 +777,24 @@ def forward(
             cond
         )
 
-        phoneme_ids = torch.where(
+        cond_ids = torch.where(
             rearrange(cond_drop_mask, '... -> ... 1'),
-            self.null_phoneme_id,
-            phoneme_ids
+            self.null_cond_id,
+            cond_token_ids
         )
 
-        phoneme_emb = self.to_phoneme_emb(phoneme_ids)
-        embed = torch.cat((x, phoneme_emb, cond), dim = -1)
+        cond_emb = self.to_cond_emb(cond_ids)
+
+        # todo: align the conditioning embed if needed
+        # (needed for spear-tts semantic <-> audio)
+
+        min_length = min([t.shape[-2] for t in (x, cond_emb, cond)])
+        x, cond_emb, cond = tuple(t[..., :min_length, :] for t in (x, cond_emb, cond))
+
+        # concat source signal, semantic / phoneme conditioning embed, and conditioning
+        # and project
+
+        embed = torch.cat((x, cond_emb, cond), dim = -1)
 
         x = self.to_embed(embed)
 
         x = self.conv_embed(x) + x
@@ -799,6 +812,9 @@ def forward(
         if not exists(target):
             return x
 
+        target = target[..., :min_length, :]
+
         if not exists(mask):
             return F.mse_loss(x, target)
 
+        mask = mask[..., :min_length]
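
Note: the torch.where rewrite above is the conditioning dropout used for classifier-free guidance, now generic over phoneme or semantic ids. Per sample, with probability cond_drop_prob, every conditioning id is replaced by a reserved null id that indexes the extra last row of the embedding table. Below is a minimal self-contained sketch of that mechanic; the prob_mask_like helper and all shapes are illustrative assumptions, not code from this diff.

import torch
import torch.nn as nn
from einops import rearrange

def prob_mask_like(shape, prob, device = None):
    # True where a sample's conditioning should be dropped entirely (assumed helper)
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

num_cond_tokens = 256
to_cond_emb = nn.Embedding(num_cond_tokens + 1, 512)  # extra row for the null token
null_cond_id = num_cond_tokens                        # last index acts as the null id

cond_token_ids = torch.randint(0, num_cond_tokens, (4, 100))  # (batch, seq)
cond_drop_mask = prob_mask_like((4,), 0.1)                    # (batch,)

cond_ids = torch.where(
    rearrange(cond_drop_mask, '... -> ... 1'),  # broadcast per-sample mask over seq
    null_cond_id,
    cond_token_ids
)

cond_emb = to_cond_emb(cond_ids)  # (batch, seq, 512)

At sampling time, forward_with_cond_scale presumably runs the network once with the real ids and once with all-null ids, blending the two outputs by cond_scale, in the usual classifier-free guidance fashion.
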
@@ -825,6 +841,7 @@ class ConditionalFlowMatcherWrapper(Module):
     def __init__(
         self,
         voicebox: VoiceBox,
+        text_to_semantic: Optional[TextToSemantic] = None,
         sigma = 0.,
         ode_atol = 1e-5,
         ode_rtol = 1e-5,
@@ -839,6 +856,8 @@ def __init__(
         self.voicebox = voicebox
 
+        self.text_to_semantic = text_to_semantic
+
         self.cond_drop_prob = cond_drop_prob
 
         self.use_torchode = use_torchode
@@ -859,8 +878,11 @@ def device(self):
     def sample(
         self,
         *,
-        phoneme_ids,
         cond,
+        texts: Optional[List[str]] = None,
+        text_token_ids: Optional[Tensor] = None,
+        semantic_token_ids = None,
+        phoneme_ids = None,
         mask = None,
         steps = 3,
         cond_scale = 1.,
@@ -879,6 +901,34 @@
             self.voicebox.audio_enc_dec.eval()
             cond = self.voicebox.audio_enc_dec.encode(cond)
 
+        # setup text conditioning, either coming from the duration model (as phoneme ids)
+        # or coming from the text-to-semantic module of the spear-tts paper (as semantic ids)
+        # todo: DRY the conditioning logic, if sampling and training turn out not too different once everything is done
+
+        assert sum([*map(exists, (texts, text_token_ids, semantic_token_ids, phoneme_ids))]) <= 1
+
+        using_text_to_semantic = exists(texts) or exists(text_token_ids)
+
+        if using_text_to_semantic:
+            assert exists(self.text_to_semantic), 'TextToSemantic must be passed into the ConditionalFlowMatcherWrapper as `text_to_semantic` in order to sample directly from text'
+
+            semantic_token_ids = self.text_to_semantic.generate(
+                source = default(text_token_ids, texts),
+                source_type = 'text',
+                target_type = 'speech',
+                max_length = 10
+            )
+
+        cond_token_ids = semantic_token_ids if using_text_to_semantic else phoneme_ids
+
+        # todo: properly align for text-to-semantic
+
+        min_length = min(cond_token_ids.shape[-1], cond.shape[-2])
+        cond_token_ids = cond_token_ids[..., :min_length]
+        cond = cond[..., :min_length, :]
+
+        # neural ode
+
         self.voicebox.eval()
 
         def fn(t, x, *, packed_shape = None):
@@ -888,7 +938,7 @@
             out = self.voicebox.forward_with_cond_scale(
                 x,
                 times = t,
-                phoneme_ids = phoneme_ids,
+                cond_token_ids = cond_token_ids,
                 cond = cond,
                 cond_scale = cond_scale
             )
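
Note: the length reconciliation added to sample has to respect two different layouts: the conditioning token ids are (batch, seq) with the sequence dimension at -1, while the encoded audio conditioning is (batch, seq, dim) with the sequence dimension at -2. A small self-contained sketch, with made-up shapes:

import torch

cond_token_ids = torch.randint(0, 256, (2, 100))  # (batch, seq) - one id per frame
cond = torch.randn(2, 98, 512)                    # (batch, seq, dim) - encoded audio

min_length = min(cond_token_ids.shape[-1], cond.shape[-2])

cond_token_ids = cond_token_ids[..., :min_length]  # trim along dim -1
cond = cond[..., :min_length, :]                   # trim along dim -2

assert cond_token_ids.shape == (2, 98)
assert cond.shape == (2, 98, 512)
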
@@ -942,8 +992,11 @@ def forward(
         self,
         x1,
         *,
-        phoneme_ids,
         cond,
+        texts: Optional[List[str]] = None,
+        text_token_ids: Optional[Tensor] = None,
+        semantic_token_ids = None,
+        phoneme_ids = None,
         mask = None
     ):
         """
@@ -968,6 +1021,28 @@
         if cond_is_raw_audio:
             cond = self.voicebox.audio_enc_dec.encode(cond)
 
+        # setup text conditioning, either coming from the duration model (as phoneme ids)
+        # or coming from the text-to-semantic module of the spear-tts paper (as semantic ids)
+
+        assert sum([*map(exists, (texts, text_token_ids, semantic_token_ids, phoneme_ids))]) <= 1
+
+        using_text_to_semantic = exists(texts) or exists(text_token_ids)
+
+        if using_text_to_semantic:
+            assert exists(self.text_to_semantic), 'TextToSemantic must be passed into the ConditionalFlowMatcherWrapper as `text_to_semantic` in order to train directly on text'
+
+            semantic_token_ids = self.text_to_semantic.generate(
+                source = default(text_token_ids, texts),
+                source_type = 'text',
+                target_type = 'speech',
+                max_length = 10
+            )
+
+        cond_token_ids = semantic_token_ids if using_text_to_semantic else phoneme_ids
+
+        # main conditional flow logic is below, in a mere 5 loc
+
         # x0 is gaussian noise
 
         x0 = torch.randn_like(x1)
@@ -989,11 +1064,11 @@
         loss = self.voicebox(
             w,
-            phoneme_ids = phoneme_ids,
             cond = cond,
             mask = mask,
             times = times,
             target = flow,
+            cond_token_ids = cond_token_ids,
             cond_drop_prob = self.cond_drop_prob
         )
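
Note: end to end, the renamed conditioning API can be exercised as in the sketch below. It follows the shape of the repository README's toy example, adapted to the keyword renames in this diff; the dimensions are placeholders, and text_to_semantic is left at its None default so the conditioning still comes from phoneme ids. Passing texts or text_token_ids instead would require supplying a trained spear-tts TextToSemantic.

import torch
from voicebox_pytorch import VoiceBox, ConditionalFlowMatcherWrapper

model = VoiceBox(
    dim = 512,
    num_cond_tokens = 256,  # formerly num_phoneme_tokens
    dim_cond_emb = 512,     # formerly dim_phoneme_emb
    depth = 2
)

cfm_wrapper = ConditionalFlowMatcherWrapper(
    voicebox = model  # text_to_semantic defaults to None
)

x1 = torch.randn(2, 1024, 512)                  # audio latents the flow is learned towards
cond = torch.randn(2, 1024, 512)                # conditioning audio latents
phoneme_ids = torch.randint(0, 256, (2, 1024))

# training: at most one of texts / text_token_ids / semantic_token_ids / phoneme_ids

loss = cfm_wrapper(
    x1,
    cond = cond,
    phoneme_ids = phoneme_ids
)

loss.backward()

# sampling mirrors the same mutually exclusive conditioning arguments

sampled = cfm_wrapper.sample(
    cond = cond,
    phoneme_ids = phoneme_ids
)
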