diff --git a/docs/README.md b/docs/README.md index bf1588b..2300e15 100644 --- a/docs/README.md +++ b/docs/README.md @@ -8,6 +8,7 @@ - **Multilingual phonemization** — 30+ phonemizer backends for dozens of languages - **Multi-engine support** — load voices from Piper, Mimic3, Coqui, Transformers, and native phoonnx format - **Voice manager** — download and cache models from HuggingFace and other sources +- **Phoneme alignment** — optional per-phoneme timing for visemes, lip-sync, and karaoke - **Training pipeline** — preprocess datasets and train new VITS voices (`phoonnx_train`) - **OVOS plugin** — drop-in TTS plugin for the OpenVoiceOS / Mycroft ecosystem @@ -43,6 +44,7 @@ with wave.open("output.wav", "wb") as wav_file: - [Installation](installation.md) - [Usage Guide](usage.md) +- [Phoneme Alignment](alignment.md) - [Voice Manager](voice_manager.md) - [Phonemizers](phonemizers.md) - [Configuration Reference](configuration.md) diff --git a/docs/alignment.md b/docs/alignment.md new file mode 100644 index 0000000..c3c731d --- /dev/null +++ b/docs/alignment.md @@ -0,0 +1,155 @@ +# Phoneme Alignment + +Phoneme alignment gives you per-phoneme timing: how many audio samples each phoneme +occupies in the synthesized output. This is the foundation for visemes (lip-sync), +karaoke-style word highlighting, and subtitle generation. + +> **Model support required.** Alignment is an optional second output of the ONNX +> model. Standard exported models do **not** include it. See +> [Exporting a model with alignment support](#exporting-a-model-with-alignment-support) +> below. + +--- + +## Getting alignments from `synthesize()` + +Pass `include_alignments=True` to `TTSVoice.synthesize()`: + +```python +from phoonnx.voice import TTSVoice + +voice = TTSVoice.load("model-aligned.onnx") + +for chunk in voice.synthesize("Hello world.", include_alignments=True): + print(f"phonemes : {chunk.phonemes}") + print(f"phoneme ids: {chunk.phoneme_ids}") + + if chunk.phoneme_alignments: + for align in chunk.phoneme_alignments: + duration_ms = align.num_samples / chunk.sample_rate * 1000 + print(f" {align.phoneme!r:6s} {duration_ms:6.1f} ms") + else: + # Model does not expose alignment output, or reconstruction failed + print(" (no alignment available)") +``` + +### `AudioChunk` alignment fields + +| Field | Type | Description | +|---|---|---| +| `phonemes` | `list[str]` | Phoneme tokens for this sentence | +| `phoneme_ids` | `list[int]` | Integer IDs passed to the ONNX model | +| `phoneme_id_samples` | `np.ndarray \| None` | Raw sample counts per phoneme ID (from model) | +| `phoneme_alignments` | `list[PhonemeAlignment] \| None` | Reconstructed per-phoneme timings | + +`phoneme_alignments` is `None` when: +- `include_alignments=False` (default) +- The model has only one output (does not support alignment) +- The alignment reconstruction fails (ID sequence mismatch) + +`phonemes` and `phoneme_ids` are always populated regardless of `include_alignments`. + +### `PhonemeAlignment` + +```python +@dataclass +class PhonemeAlignment: + phoneme: str # the phoneme token, e.g. "h", "ɛ", "l" + num_samples: int # number of PCM samples occupied by this phoneme +``` + +Convert `num_samples` to milliseconds: `num_samples / chunk.sample_rate * 1000`. + +--- + +## Lower-level: `phoneme_ids_to_audio()` + +If you are working at the phoneme-ID level you can call the method directly: + +```python +audio_or_tuple = voice.phoneme_ids_to_audio(phoneme_ids, include_alignments=True) + +if isinstance(audio_or_tuple, tuple): + audio, phoneme_id_samples = audio_or_tuple + # phoneme_id_samples[i] = samples for phoneme_ids[i], or None if unsupported +else: + audio = audio_or_tuple # include_alignments=False +``` + +--- + +## `hop_length` + +The raw model output is a duration in frames; phoonnx converts frames to samples +using `hop_length` (default **256**, matching the standard VITS vocoder hop size). + +Override it in the voice config JSON: + +```json +{ + "hop_length": 256 +} +``` + +Or via `VoiceConfig`: + +```python +voice.config.hop_length = 256 +``` + +--- + +## Exporting a model with alignment support + +Standard VITS models expose only the audio tensor. To expose the phoneme-duration +tensor as a second output, use the `--add-phoneme-alignment` flag when exporting: + +```bash +phoonnx-train export-onnx checkpoint.ckpt -c config.json --add-phoneme-alignment +``` + +This modifies the exported `.onnx` graph to surface the `Ceil` node output (phoneme +durations) as a named model output. The modification is done by +`add_phoneme_alignment_output()` in `phoonnx_train/export_onnx.py`. + +You can also apply it post-hoc to an already-exported model: + +```python +from phoonnx_train.export_onnx import add_phoneme_alignment_output +from pathlib import Path + +add_phoneme_alignment_output( + model_path=Path("model.onnx"), + output_path=Path("model-aligned.onnx"), # omit to overwrite in place + tensor_name="autodetect", # or pass the tensor name explicitly +) +``` + +> **Compatibility note.** Adding the alignment output may break third-party +> frameworks (e.g. Piper) that expect a single output tensor. Keep a separate +> copy of the model for standard TTS use. + +--- + +## Use cases + +### Visemes / lip-sync + +Map each phoneme to a viseme index, then schedule face-rig blend-shape changes +at the sample offset accumulated from `num_samples`: + +```python +VISEME = {"p": 0, "b": 0, "m": 0, "f": 1, "v": 1, ...} # your mapping + +offset = 0 +for align in chunk.phoneme_alignments or []: + t = offset / chunk.sample_rate + viseme = VISEME.get(align.phoneme, -1) + schedule_viseme(t, viseme) + offset += align.num_samples +``` + +### Karaoke / subtitle highlighting + +Accumulate sample offsets to get word-start timestamps, then use them to +synchronise text highlights with audio playback. diff --git a/phoonnx/config.py b/phoonnx/config.py index dfd90da..b1f974e 100644 --- a/phoonnx/config.py +++ b/phoonnx/config.py @@ -10,6 +10,7 @@ DEFAULT_NOISE_SCALE = 0.667 DEFAULT_LENGTH_SCALE = 1.0 DEFAULT_NOISE_W_SCALE = 0.8 +DEFAULT_HOP_LENGTH = 256 class Engine(str, Enum): @@ -121,6 +122,8 @@ class VoiceConfig: noise_w_scale: float = DEFAULT_NOISE_W_SCALE add_diacritics: bool = None # arabic and hebrew + hop_length: int = DEFAULT_HOP_LENGTH + # tokenization settings tokenizer: Optional[TTSTokenizer] = None blank_at_start: bool = True @@ -383,6 +386,7 @@ def from_dict(config: dict[str, Any], # phoonnx/piper/coqui/mimic3 length_scale=inference.get("length_scale", DEFAULT_LENGTH_SCALE), noise_w_scale=inference.get("noise_w", DEFAULT_NOISE_W_SCALE), add_diacritics=diacritics, + hop_length=config.get("hop_length", DEFAULT_HOP_LENGTH), lang_code=lang_code, alphabet=Alphabet(alphabet) if isinstance(alphabet, str) else alphabet, engine=Engine(engine) if isinstance(engine, str) else engine, diff --git a/phoonnx/phoneme_ids.py b/phoonnx/phoneme_ids.py new file mode 100644 index 0000000..5f8489c --- /dev/null +++ b/phoonnx/phoneme_ids.py @@ -0,0 +1,457 @@ +"""Utilities for converting phonemes to ids.""" + +from collections.abc import Sequence +from enum import Enum +from typing import Optional, TextIO, Dict, List, Union, Set, Mapping + +try: + from ovos_utils.log import LOG +except ImportError: + import logging + + LOG = logging.getLogger(__name__) + LOG.setLevel("DEBUG") + +PHONEME_ID_LIST = List[int] +PHONEME_ID_MAP = Dict[str, int] +PHONEME_LIST = List[str] +PHONEME_WORD_LIST = List[PHONEME_LIST] + +DEFAULT_IPA_PHONEME_ID_MAP: Dict[str, PHONEME_ID_LIST] = { + "_": [0], + "^": [1], + "$": [2], + " ": [3], + "!": [4], + "'": [5], + "(": [6], + ")": [7], + ",": [8], + "-": [9], + ".": [10], + ":": [11], + ";": [12], + "?": [13], + "a": [14], + "b": [15], + "c": [16], + "d": [17], + "e": [18], + "f": [19], + "h": [20], + "i": [21], + "j": [22], + "k": [23], + "l": [24], + "m": [25], + "n": [26], + "o": [27], + "p": [28], + "q": [29], + "r": [30], + "s": [31], + "t": [32], + "u": [33], + "v": [34], + "w": [35], + "x": [36], + "y": [37], + "z": [38], + "æ": [39], + "ç": [40], + "ð": [41], + "ø": [42], + "ħ": [43], + "ŋ": [44], + "œ": [45], + "ǀ": [46], + "ǁ": [47], + "ǂ": [48], + "ǃ": [49], + "ɐ": [50], + "ɑ": [51], + "ɒ": [52], + "ɓ": [53], + "ɔ": [54], + "ɕ": [55], + "ɖ": [56], + "ɗ": [57], + "ɘ": [58], + "ə": [59], + "ɚ": [60], + "ɛ": [61], + "ɜ": [62], + "ɞ": [63], + "ɟ": [64], + "ɠ": [65], + "ɡ": [66], + "ɢ": [67], + "ɣ": [68], + "ɤ": [69], + "ɥ": [70], + "ɦ": [71], + "ɧ": [72], + "ɨ": [73], + "ɪ": [74], + "ɫ": [75], + "ɬ": [76], + "ɭ": [77], + "ɮ": [78], + "ɯ": [79], + "ɰ": [80], + "ɱ": [81], + "ɲ": [82], + "ɳ": [83], + "ɴ": [84], + "ɵ": [85], + "ɶ": [86], + "ɸ": [87], + "ɹ": [88], + "ɺ": [89], + "ɻ": [90], + "ɽ": [91], + "ɾ": [92], + "ʀ": [93], + "ʁ": [94], + "ʂ": [95], + "ʃ": [96], + "ʄ": [97], + "ʈ": [98], + "ʉ": [99], + "ʊ": [100], + "ʋ": [101], + "ʌ": [102], + "ʍ": [103], + "ʎ": [104], + "ʏ": [105], + "ʐ": [106], + "ʑ": [107], + "ʒ": [108], + "ʔ": [109], + "ʕ": [110], + "ʘ": [111], + "ʙ": [112], + "ʛ": [113], + "ʜ": [114], + "ʝ": [115], + "ʟ": [116], + "ʡ": [117], + "ʢ": [118], + "ʲ": [119], + "ˈ": [120], + "ˌ": [121], + "ː": [122], + "ˑ": [123], + "˞": [124], + "β": [125], + "θ": [126], + "χ": [127], + "ᵻ": [128], + "ⱱ": [129], + "0": [130], + "1": [131], + "2": [132], + "3": [133], + "4": [134], + "5": [135], + "6": [136], + "7": [137], + "8": [138], + "9": [139], + "̧": [140], + "̃": [141], + "̪": [142], + "̯": [143], + "̩": [144], + "ʰ": [145], + "ˤ": [146], + "ε": [147], + "↓": [148], + "#": [149], + '"': [150], + "↑": [151], + "̺": [152], + "̻": [153], + "g": [154], + "ʦ": [155], + "X": [156], + "̝": [157], + "̊": [158], + "ɝ": [159], + "ʷ": [160], +} + +DEFAULT_PAD_TOKEN = DEFAULT_BLANK_TOKEN = "_" # padding (0) +DEFAULT_BOS_TOKEN = "^" # beginning of sentence +DEFAULT_EOS_TOKEN = "$" # end of sentence +DEFAULT_BLANK_WORD_TOKEN = " " # padding between words + +STRESS: Set[str] = {"ˈ", "ˌ"} +"""Default stress characters""" + +# TODO - consider keeping ? to help intonation with questions +# in some languages (eg. Portuguese) this is the +# only difference between a statement and a question +# and it is important for tone +PUNCTUATION_MAP: Mapping[str, str] = {";": ",", ":": ",", "?": ".", "!": "."} +"""Default punctuation simplification into short (,) and long (.) pauses""" + + +class BlankBetween(str, Enum): + """Placement of blank tokens""" + + TOKENS = "tokens" + """Blank between every token/phoneme""" + + WORDS = "words" + """Blank between every word""" + + TOKENS_AND_WORDS = "tokens_and_words" + """Blank between every token/phoneme and every word (may be different symbols)""" + + +def phonemes_to_ids( + phonemes: PHONEME_LIST, + id_map: Optional[Mapping[str, Union[int, Sequence[int]]]] = None, + blank_token: Optional[str] = DEFAULT_BLANK_TOKEN, + bos_token: Optional[str] = DEFAULT_BOS_TOKEN, + eos_token: Optional[str] = DEFAULT_EOS_TOKEN, + word_sep_token: Optional[str] = DEFAULT_BLANK_WORD_TOKEN, + include_whitespace: Optional[bool] = True, + blank_at_start: bool = True, + blank_at_end: bool = True, + blank_between: BlankBetween = BlankBetween.TOKENS_AND_WORDS, +) -> PHONEME_ID_LIST: + """Phonemes to ids.""" + if not phonemes: + return [] + if not id_map: + id_map = DEFAULT_IPA_PHONEME_ID_MAP + + # compat with piper-style mapping that uses lists + id_map = {k: v if isinstance(v, list) else [v] + for k, v in id_map.items()} + + ids: list[int] = [] + blank_id = blank_token if isinstance(blank_token, int) \ + else id_map.get(blank_token, [len(id_map)]) if blank_token \ + else [len(id_map)] + eos_id = eos_token if isinstance(eos_token, int) \ + else id_map.get(eos_token, [len(id_map)]) if eos_token \ + else [len(id_map)] + bos_id = eos_token if isinstance(bos_token, int) \ + else id_map.get(bos_token, [len(id_map)]) if bos_token \ + else [len(id_map)] + + if bos_token is not None: + ids.extend(bos_id) + if blank_token is not None and blank_at_start: + ids.extend(blank_id) + + blank_between_tokns = (blank_token is not None and + blank_between in [BlankBetween.TOKENS, BlankBetween.TOKENS_AND_WORDS]) + blank_between_words = (blank_token is not None and + blank_between in [BlankBetween.WORDS, BlankBetween.TOKENS_AND_WORDS]) + + # first pre-process phoneme_map to check for dipthongs having their own phoneme_id + # common in mimic3 models + compound_phonemes = sorted((k for k in id_map if len(k) > 1), key=len, reverse=True) + i = 0 + while i < len(phonemes): + matched = False + + # Try to match compound phonemes starting at index i + for compound in compound_phonemes: + n = len(compound) + joined = ''.join(phonemes[i:i + n]) + if joined == compound: + ids.extend(id_map[compound]) + if blank_between_tokns and i + n < len(phonemes): + ids.extend(blank_id) + i += n + matched = True + break + + if matched: + continue + + phoneme = phonemes[i] + if phoneme not in id_map: + if phoneme == " " and not include_whitespace: + i += 1 + continue + LOG.warning("Missing phoneme from id map: %s", phoneme) + i += 1 + continue + + if phoneme == " ": + if include_whitespace: + ids.extend(id_map[phoneme]) + if blank_between_tokns: + ids.extend(blank_id) + elif blank_between_words: + ids.extend(id_map[word_sep_token]) + if blank_between_tokns: + ids.extend(blank_id) + else: + ids.extend(id_map[phoneme]) + if blank_between_tokns and i < len(phonemes) - 1: + ids.extend(blank_id) + i += 1 + + if blank_token is not None and blank_at_end: + if not include_whitespace and word_sep_token and blank_between_words: + if blank_between_tokns: + ids.extend(blank_id) + ids.extend(id_map[word_sep_token]) + if blank_between_tokns: + ids.extend(blank_id) + else: + ids.extend(blank_id) + if eos_token is not None: + ids.extend(eos_id) + + return ids + +def load_phoneme_ids(phonemes_file: TextIO) -> PHONEME_ID_MAP: + """ + Load phoneme id mapping from a text file. + Format is IDPHONEME + Comments start with # + + Args: + phonemes_file: text file + + Returns: + dict with phoneme -> id + """ + phoneme_to_id = {} + for line in phonemes_file: + line = line.strip("\r\n") + if (not line) or line.startswith("#") or (" " not in line): + # Exclude blank lines, comments, or malformed lines + continue + + if line.strip().isdigit(): # phoneme is whitespace + phoneme_str = " " + phoneme_id = int(line) + else: + phoneme_id, phoneme_str = line.split(" ", maxsplit=1) + if phoneme_str.isdigit(): + phoneme_id, phoneme_str = phoneme_str, phoneme_id + + phoneme_to_id[phoneme_str] = int(phoneme_id) + + return phoneme_to_id + + +def load_phoneme_map(phoneme_map_file: TextIO) -> Dict[str, List[str]]: + """ + Load phoneme/phoneme mapping from a text file. + Format is FROM_PHONEMETO_PHONEME[TO_PHONEME...] + Comments start with # + + Args: + phoneme_map_file: text file + + Returns: + dict with from_phoneme -> [to_phoneme, to_phoneme, ...] + """ + phoneme_map = {} + for line in phoneme_map_file: + line = line.strip("\r\n") + if (not line) or line.startswith("#") or (" " not in line): + # Exclude blank lines, comments, or malformed lines + continue + + from_phoneme, to_phonemes_str = line.split(" ", maxsplit=1) + if not to_phonemes_str.strip(): + # To whitespace + phoneme_map[from_phoneme] = [" "] + else: + # To one or more non-whitespace phonemes + phoneme_map[from_phoneme] = to_phonemes_str.split() + + return phoneme_map + + +if __name__ == "__main__": + phoneme_ids_path = "/home/miro/PycharmProjects/phoonnx_tts/mimic3_ap/phonemes.txt" + with open(phoneme_ids_path, "r", encoding="utf-8") as ids_file: + phoneme_to_id = load_phoneme_ids(ids_file) + print(phoneme_to_id) + + phoneme_map_path = "/home/miro/PycharmProjects/phoonnx_tts/mimic3_ap/phoneme_map.txt" + with open(phoneme_map_path, "r", encoding="utf-8") as map_file: + phoneme_map = load_phoneme_map(map_file) + # print(phoneme_map) + + from phoonnx.phonemizers import EspeakPhonemizer + + # test original mimic3 code + from phonemes2ids import phonemes2ids as mimic3_phonemes2ids + + # test original piper code + from piper.phoneme_ids import phonemes_to_ids as piper_phonemes_to_ids + + espeak = EspeakPhonemizer() + phone_str: str = espeak.phonemize_string("hello world", "en") + + phones: PHONEME_LIST = list(phone_str) + phone_words: PHONEME_WORD_LIST = [list(w) for w in phone_str.split()] + print(phone_str) + print(phones) # piper style + print(phone_words) # mimic3 style + + mapping = {k: v[0] for k, v in DEFAULT_IPA_PHONEME_ID_MAP.items()} + print("\n#### piper (tokens_and_words + include_whitespace)") + print("reference", piper_phonemes_to_ids(phones)) + print("phonnx ", phonemes_to_ids(phones, + id_map=mapping, include_whitespace=True)) + + print("\n#### mimic3 (words)") + print("reference", mimic3_phonemes2ids(phone_words, + mapping, + bos=DEFAULT_BOS_TOKEN, + eos=DEFAULT_EOS_TOKEN, + blank=DEFAULT_PAD_TOKEN, + blank_at_end=True, + blank_at_start=True, + blank_word=DEFAULT_BLANK_WORD_TOKEN, + blank_between=BlankBetween.WORDS, + auto_bos_eos=True)) + print("phonnx ", phonemes_to_ids(phones, + id_map=mapping, + include_whitespace=False, + blank_between=BlankBetween.WORDS)) + + print("\n#### mimic3 (tokens)") + print("reference", mimic3_phonemes2ids(phone_words, + mapping, + bos=DEFAULT_BOS_TOKEN, + eos=DEFAULT_EOS_TOKEN, + blank=DEFAULT_PAD_TOKEN, + blank_at_end=True, + blank_at_start=True, + blank_word=DEFAULT_BLANK_WORD_TOKEN, + blank_between=BlankBetween.TOKENS, + auto_bos_eos=True)) + print("phonnx ", phonemes_to_ids(phones, + id_map=mapping, + include_whitespace=False, + blank_between=BlankBetween.TOKENS)) + print("\n#### mimic3 (tokens_and_words)") + print("reference", mimic3_phonemes2ids(phone_words, + mapping, + bos=DEFAULT_BOS_TOKEN, + eos=DEFAULT_EOS_TOKEN, + blank=DEFAULT_PAD_TOKEN, + blank_at_end=True, + blank_at_start=True, + blank_word=DEFAULT_BLANK_WORD_TOKEN, + blank_between=BlankBetween.TOKENS_AND_WORDS, + auto_bos_eos=True)) + print("phonnx ", phonemes_to_ids(phones, + id_map=mapping, + include_whitespace=False, + blank_between=BlankBetween.TOKENS_AND_WORDS)) diff --git a/phoonnx/voice.py b/phoonnx/voice.py index 1d56654..84b32d1 100644 --- a/phoonnx/voice.py +++ b/phoonnx/voice.py @@ -4,7 +4,7 @@ import wave from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Iterable, Optional, Union, Dict +from typing import Any, Iterable, Optional, Union, Dict, Tuple, List import numpy as np import onnxruntime @@ -52,6 +52,12 @@ def apply(self, text: str) -> str: return text +@dataclass +class PhonemeAlignment: + phoneme: str + num_samples: int + + @dataclass class AudioChunk: """Chunk of raw audio.""" @@ -68,8 +74,23 @@ class AudioChunk: audio_float_array: np.ndarray """Audio data as float numpy array in [-1, 1].""" + phonemes: list[str] + """Phonemes that produced this audio chunk.""" + + phoneme_ids: list[int] + """Phoneme ids that produced this audio chunk.""" + + phoneme_id_samples: Optional[np.ndarray] = None + """Number of audio samples for each phoneme id (alignments).""" + + phoneme_alignments: Optional[list[PhonemeAlignment]] = None + """Alignments between phonemes and audio samples.""" + + # --- + _audio_int16_array: Optional[np.ndarray] = None _audio_int16_bytes: Optional[bytes] = None + _phoneme_alignments: Optional[list[PhonemeAlignment]] = None _MAX_WAV_VALUE: float = 32767.0 @property @@ -260,18 +281,14 @@ def synthesize( self, text: str, syn_config: Optional[SynthesisConfig] = None, + include_alignments: bool = False, ) -> Iterable[AudioChunk]: """ - Synthesize speech from input text, yielding one AudioChunk per sentence. - - Generates sentence-level audio by phonemizing the input text and synthesizing each sentence into a float32 PCM audio array in the range [-1.0, 1.0]. If enabled in the synthesis configuration, user-provided phonetic spellings and diacritic augmentation are applied before phonemization. Output audio may be normalized and volume-scaled according to the configuration; samples are clipped to [-1.0, 1.0]. - - Parameters: - text (str): The input text to synthesize. - syn_config (Optional[SynthesisConfig]): Optional synthesis options (e.g., enable_phonetic_spellings, add_diacritics, normalize_audio, volume). If omitted, a default SynthesisConfig is used. - - Returns: - Iterable[AudioChunk]: An iterable that yields one AudioChunk per synthesized sentence. Each AudioChunk contains a float32 audio array, sample rate taken from the voice config, 2-byte sample width, and 1 channel. + Synthesize one audio chunk per sentence from text. + + :param text: Text to synthesize. + :param syn_config: Synthesis configuration. + :param include_alignments: If True and the model supports it, include phoneme/audio alignments. """ if syn_config is None: syn_config = SynthesisConfig() @@ -289,15 +306,31 @@ def synthesize( # All phonemization goes through the unified self.phonemize method sentence_phonemes = self.phonemize(text) LOG.debug("phonemes=%s", sentence_phonemes) - all_phoneme_ids_for_synthesis = [ - self.phonemes_to_ids(phonemes) for phonemes in sentence_phonemes if phonemes - ] - for phoneme_ids in all_phoneme_ids_for_synthesis: - if not phoneme_ids: + # Ids for special tokens — used by the alignment reconstruction below. + # These live in the tokenizer, not in VoiceConfig, so read them once here. + _tok = self.config.tokenizer + _vocab = _tok.vocabulary + _idx2char: dict = _vocab.idx2char + _blank_id: int = _tok.blank_id + _bos_id: int = _vocab.bos_id + _eos_id: int = _vocab.eos_id + + for phonemes in sentence_phonemes: + if not phonemes: continue + phoneme_ids = self.phonemes_to_ids(phonemes) - audio = self.phoneme_ids_to_audio(phoneme_ids, syn_config) + phoneme_id_samples: Optional[np.ndarray] = None + audio_result = self.phoneme_ids_to_audio( + phoneme_ids, syn_config, include_alignments=include_alignments + ) + if isinstance(audio_result, tuple): + # Audio + alignments + audio, phoneme_id_samples = audio_result + else: + # Audio only + audio = audio_result if syn_config.normalize_audio: max_val = np.max(np.abs(audio)) @@ -312,11 +345,40 @@ def synthesize( audio = np.clip(audio, -1.0, 1.0).astype(np.float32) + phoneme_alignments: Optional[list[PhonemeAlignment]] = None + if phoneme_id_samples is not None and len(phoneme_id_samples) == len(phoneme_ids): + # Walk phoneme_ids + per-id sample counts together. + # Blank/bos durations are absorbed forward into the next real phoneme. + # Eos duration is absorbed backward into the last real phoneme. + # This works regardless of the tokenizer's blank_at_start/end settings. + alignments: List[PhonemeAlignment] = [] + pending: int = 0 + for pid, n_samples in zip(phoneme_ids, phoneme_id_samples.tolist()): + if pid == _bos_id or pid == _blank_id: + pending += n_samples + elif pid == _eos_id: + if alignments: + prev = alignments[-1] + alignments[-1] = PhonemeAlignment(prev.phoneme, prev.num_samples + n_samples) + else: + token = _idx2char.get(pid, "?") + alignments.append(PhonemeAlignment(token, pending + n_samples)) + pending = 0 + # Absorb any leftover pending (e.g. leading blank with no following phoneme) + if pending and alignments: + prev = alignments[-1] + alignments[-1] = PhonemeAlignment(prev.phoneme, prev.num_samples + pending) + phoneme_alignments = alignments or None + yield AudioChunk( sample_rate=self.config.sample_rate, sample_width=2, sample_channels=1, audio_float_array=audio, + phonemes=phonemes, + phoneme_ids=phoneme_ids, + phoneme_id_samples=phoneme_id_samples, + phoneme_alignments=phoneme_alignments, ) def synthesize_wav( @@ -357,14 +419,21 @@ def synthesize_wav( wav_file.writeframes(audio_chunk.audio_int16_bytes) def phoneme_ids_to_audio( - self, phoneme_ids: list[int], syn_config: Optional[SynthesisConfig] = None - ) -> np.ndarray: + self, phoneme_ids: list[int], + syn_config: Optional[SynthesisConfig] = None, + include_alignments: bool = False + ) -> Union[np.ndarray, Tuple[np.ndarray, Optional[np.ndarray]]]: """ Synthesize raw audio from phoneme ids. :param phoneme_ids: List of phoneme ids. :param syn_config: Synthesis configuration. + :param include_alignments: Return samples per phoneme id if True. :return: Audio float numpy array from voice model (unnormalized, in range [-1, 1]). + + If include_alignments is True and the voice model supports it, the return + value will be a tuple instead with (audio, phoneme_id_samples) where + phoneme_id_samples contains the number of audio samples per phoneme id. """ syn_config = syn_config or SynthesisConfig() @@ -412,8 +481,22 @@ def phoneme_ids_to_audio( # Keep only ONNX model-expected inputs args = {k: v for k, v in args.items() if k in expected_args} - audio = self.session.run(None, args)[0].squeeze() - return audio + # Synthesize through onnx + result = self.session.run(None, args) + audio = result[0].squeeze() + if not include_alignments: + return audio + + if len(result) == 1: + # Alignment is not available from voice model + return audio, None + + # Number of samples for each phoneme id + phoneme_id_samples = (result[1].squeeze() * self.config.hop_length).astype( + np.int64 + ) + + return audio, phoneme_id_samples diff --git a/phoonnx_train/export_onnx.py b/phoonnx_train/export_onnx.py index 63cc811..1194040 100644 --- a/phoonnx_train/export_onnx.py +++ b/phoonnx_train/export_onnx.py @@ -1,14 +1,14 @@ #!/usr/bin/env python3 -import click -import logging import json +import logging import os from pathlib import Path -from typing import Optional, Dict, Any, Tuple +from typing import Optional, Dict, Any, Tuple, Set +import click import torch + from phoonnx_train.vits.lightning import VitsModel -from phoonnx.version import VERSION_STR # Basic logging configuration logging.basicConfig(level=logging.DEBUG) @@ -131,12 +131,97 @@ def convert_to_piper(config_path: Path, output_path: Path = Path("piper.json")) json.dump(piper_config, f, indent=4, ensure_ascii=False) +def add_phoneme_alignment_output(model_path: Path, output_path: Optional[Path] = None, tensor_name: str = "autodetect") -> None: + """ + Adds a tensor representing phoneme durations to the ONNX model's graph outputs. + + This might cause compatibility issues with other frameworks (eg. piper) + + Args: + model_path: Path to the input ONNX model file. + output_path: Path where the modified ONNX model will be saved. + If empty, it defaults to the input path. + tensor_name: The name of the tensor to mark as an output. + If "autodetect", it looks for a unique output of a 'Ceil' node. + + Returns: + None: The function saves the modified model to `output_path`. + """ + import onnx + + # Use model_path for default output_path (overwrite) + output_path = output_path or model_path + + # Load the ONNX model + try: + model: onnx.ModelProto = onnx.load(model_path) + except Exception as e: + _LOGGER.fatal(f"Failed to load ONNX model from {model_path}: {e}") + return + + ceil_tensor_name: str + if tensor_name != "autodetect": + ceil_tensor_name = tensor_name + else: + ceil_tensor_names: Set[str] = set() + for node in model.graph.node: + # Check for nodes with the operation type "Ceil" + if node.op_type != "Ceil": + continue + + # Add all output tensor names of the 'Ceil' node + ceil_tensor_names.update(node.output) + + if not ceil_tensor_names: + _LOGGER.fatal(f"No ceil tensors detected in {model_path}. Use --tensor-name manually.") + return + + if len(ceil_tensor_names) > 1: + # Format the set of names nicely for the error message + names_str = ', '.join(sorted(list(ceil_tensor_names))) + _LOGGER.fatal( + f"Multiple ceil tensors detected in {model_path}. Use --tensor-name manually: {names_str}" + ) + return + + # Get the single detected tensor name + ceil_tensor_name = next(iter(ceil_tensor_names)) + + _LOGGER.info(f"Detected tensor name: {ceil_tensor_name}") + + # Check if the tensor is already an output of the graph + if any(output.name == ceil_tensor_name for output in model.graph.output): + _LOGGER.fatal( + f"Tensor '{ceil_tensor_name}' is already marked as output. Aborting." + ) + return + + # Create a new ValueInfoProto for the output + ceil_value_info = onnx.helper.ValueInfoProto() + + # Set the name of the output tensor + ceil_value_info.name = ceil_tensor_name + + # Append the new output to the graph's output list + model.graph.output.append(ceil_value_info) + + # Save the modified model + try: + onnx.save(model, output_path) + except Exception as e: + _LOGGER.fatal(f"Failed to save modified ONNX model to {output_path}: {e}") + return + + _LOGGER.info(f"Successfully wrote modified model with new output '{ceil_tensor_name}' to {output_path}") + + + # --- Main Logic using Click --- @click.command(help="Export a VITS model checkpoint to ONNX format.") @click.argument( "checkpoint", type=click.Path(exists=True, path_type=Path), - # help="Path to the PyTorch checkpoint file (*.ckpt)." + # help="Path to the PyTorch checkpoint file (*.ckpt)." ) @click.option( "-c", @@ -148,9 +233,15 @@ def convert_to_piper(config_path: Path, output_path: Path = Path("piper.json")) "-o", "--output-dir", type=click.Path(path_type=Path), - default=Path(os.getcwd()), # Set default to current working directory + default=Path(os.getcwd()), # Set default to current working directory help="Output directory for the ONNX model. (Default: current directory)" ) +@click.option( + "-a", + "--add-phoneme-alignment", + is_flag=True, + help="Add a phoneme alignment output tensor to the onnx model. (might cause compatibility issues with 3rd party frameworks)" +) @click.option( "-t", "--generate-tokens", @@ -167,6 +258,7 @@ def cli( checkpoint: Path, config: Path, output_dir: Path, + add_phoneme_alignment: bool, generate_tokens: bool, piper: bool, ) -> None: @@ -200,7 +292,6 @@ def cli( _LOGGER.error(f"Error loading config file {config}: {e}") return - alphabet: str = model_config.get("alphabet", "") phoneme_type: str = model_config.get("phoneme_type", "") phonemizer_model: str = model_config.get("phonemizer_model", "") # depends on phonemizer (eg. byt5) @@ -248,7 +339,8 @@ def cli( # ------------------------------------------------------------------------- # Define ONNX-compatible forward function - def infer_forward(text: torch.Tensor, text_lengths: torch.Tensor, scales: torch.Tensor, sid: Optional[torch.Tensor] = None) -> torch.Tensor: + def infer_forward(text: torch.Tensor, text_lengths: torch.Tensor, + scales: torch.Tensor, sid: Optional[torch.Tensor] = None) -> torch.Tensor: """ Custom forward pass for ONNX export, simplifying the input scales and returning only the audio tensor with shape [B, 1, T]. @@ -352,10 +444,15 @@ def infer_forward(text: torch.Tensor, text_lengths: torch.Tensor, scales: torch. except Exception as e: _LOGGER.error(f"Failed to add metadata to exported model {model_output}: {e}") + if add_phoneme_alignment: + try: + add_phoneme_alignment_output(model_output) + except Exception as e: + _LOGGER.error(f"Failed to add phoneme_alignment output to exported model {model_output}: {e}") _LOGGER.info("Export complete.") # ----------------------------------------------------------------------------- if __name__ == "__main__": - cli() \ No newline at end of file + cli() diff --git a/tests/test_alignment.py b/tests/test_alignment.py new file mode 100644 index 0000000..38138fc --- /dev/null +++ b/tests/test_alignment.py @@ -0,0 +1,504 @@ +"""Tests for phoneme alignment (PhonemeAlignment, AudioChunk, TTSVoice alignment logic). + +Unit tests are hermetic — no real models, no network. +Integration tests require the downloaded eu-ES dii espeak voice and a patched +alignment ONNX model at ALIGNED_MODEL (set up by setUpClass). + +Run unit tests only: + pytest tests/test_alignment.py -k "not Integration" + +Run everything (requires network on first run to download the voice): + pytest tests/test_alignment.py +""" +import os +import shutil +import tempfile +import unittest +from dataclasses import fields, is_dataclass +from unittest.mock import MagicMock + +import numpy as np + +from phoonnx.config import VoiceConfig, DEFAULT_HOP_LENGTH, PhonemeType +from phoonnx.voice import AudioChunk, PhonemeAlignment, TTSVoice + + +# --------------------------------------------------------------------------- +# Test fixtures / helpers +# --------------------------------------------------------------------------- + +class _Input: + """Minimal stand-in for an onnxruntime model input.""" + def __init__(self, name): + self.name = name + + +def _fake_audio(n=200): + return np.linspace(-0.5, 0.5, n, dtype=np.float32) + + +def _make_vocab(char2idx=None): + """Build a minimal Vocabulary mock with the standard special tokens.""" + char2idx = char2idx or {"_": 0, "^": 1, "$": 2, "h": 3, "e": 4, "l": 5, "o": 6} + idx2char = {v: k for k, v in char2idx.items()} + vocab = MagicMock() + vocab.char2idx = char2idx + vocab.idx2char = idx2char + vocab.blank_id = 0 + vocab.bos_id = 1 + vocab.eos_id = 2 + vocab.pad_id = 0 + vocab.blank = "_" + vocab.bos = "^" + vocab.eos = "$" + vocab.pad = "_" + return vocab + + +def _make_tokenizer(char2idx=None): + """Build a minimal TTSTokenizer mock.""" + vocab = _make_vocab(char2idx) + tok = MagicMock() + tok.vocabulary = vocab + tok.blank_id = 0 + tok.pad_id = 0 + tok.add_blank_char = True + tok.use_eos_bos = True + tok.blank_at_start = True + tok.blank_at_end = True + return tok + + +def _make_voice(session_outputs, phonemize_result=None, token_ids=None, char2idx=None): + """Build a TTSVoice with a mocked ONNX session. + + ``voice.phonemize`` and ``voice.phonemes_to_ids`` are mocked at the + instance level to bypass all text-processing paths. + + session_outputs: list of numpy arrays the fake session.run() returns. + token_ids: the full sequence including bos/blank/eos that the tokenizer + would normally produce (matches what phoneme_id_samples covers). + """ + if phonemize_result is None: + phonemize_result = [["h", "e", "l", "l", "o"]] + if token_ids is None: + # Simple sequence: [bos, blank, h, blank, e, blank, l, blank, l, blank, o, blank, eos] + token_ids = [1, 0, 3, 0, 4, 0, 5, 0, 5, 0, 6, 0, 2] + + session = MagicMock() + session.get_inputs.return_value = [_Input("input"), _Input("input_lengths")] + session.run.return_value = session_outputs + + cfg = MagicMock(spec=VoiceConfig) + cfg.lang_code = "en-US" + cfg.sample_rate = 22050 + cfg.num_speakers = 1 + cfg.phoneme_type = PhonemeType.UNICODE + cfg.noise_scale = 0.667 + cfg.length_scale = 1.0 + cfg.noise_w_scale = 0.8 + cfg.hop_length = DEFAULT_HOP_LENGTH + cfg.enable_phonetic_spellings = False + cfg.add_diacritics = False + cfg.tokenizer = _make_tokenizer(char2idx) + + voice = TTSVoice.__new__(TTSVoice) + voice.session = session + voice.config = cfg + voice.phonemizer = MagicMock() + voice.phonetic_spellings = None + voice.phonemize = MagicMock(return_value=phonemize_result) + voice.phonemes_to_ids = MagicMock(return_value=token_ids) + return voice + + +# --------------------------------------------------------------------------- +# PhonemeAlignment dataclass +# --------------------------------------------------------------------------- + +class TestPhonemeAlignment(unittest.TestCase): + def test_is_dataclass(self): + self.assertTrue(is_dataclass(PhonemeAlignment)) + + def test_fields(self): + pa = PhonemeAlignment(phoneme="a", num_samples=512) + self.assertEqual(pa.phoneme, "a") + self.assertEqual(pa.num_samples, 512) + + def test_equality(self): + self.assertEqual(PhonemeAlignment("b", 100), PhonemeAlignment("b", 100)) + self.assertNotEqual(PhonemeAlignment("b", 100), PhonemeAlignment("b", 200)) + + +# --------------------------------------------------------------------------- +# AudioChunk new fields +# --------------------------------------------------------------------------- + +class TestAudioChunkFields(unittest.TestCase): + def _chunk(self, **kw): + defaults = dict( + sample_rate=22050, sample_width=2, sample_channels=1, + audio_float_array=_fake_audio(), phonemes=[], phoneme_ids=[], + ) + defaults.update(kw) + return AudioChunk(**defaults) + + def test_phonemes_and_ids_stored(self): + c = self._chunk(phonemes=["h", "i"], phoneme_ids=[3, 4]) + self.assertEqual(c.phonemes, ["h", "i"]) + self.assertEqual(c.phoneme_ids, [3, 4]) + + def test_optional_fields_default_none(self): + c = self._chunk() + self.assertIsNone(c.phoneme_id_samples) + self.assertIsNone(c.phoneme_alignments) + + def test_alignment_objects_stored(self): + aligns = [PhonemeAlignment("h", 256), PhonemeAlignment("i", 512)] + c = self._chunk(phoneme_alignments=aligns) + self.assertEqual(c.phoneme_alignments[0].phoneme, "h") + + def test_id_samples_array_stored(self): + arr = np.array([100, 200, 300], dtype=np.int64) + c = self._chunk(phoneme_id_samples=arr) + np.testing.assert_array_equal(c.phoneme_id_samples, arr) + + +# --------------------------------------------------------------------------- +# VoiceConfig hop_length +# --------------------------------------------------------------------------- + +class TestHopLength(unittest.TestCase): + def test_default_hop_length_constant(self): + self.assertEqual(DEFAULT_HOP_LENGTH, 256) + + def test_hop_length_field_default(self): + hop_field = next(f for f in fields(VoiceConfig) if f.name == "hop_length") + self.assertEqual(hop_field.default, DEFAULT_HOP_LENGTH) + + def _phoonnx_cfg(self, **extra): + cfg = { + "phoonnx_version": "0.0.1", + "phoneme_type": "espeak", + "alphabet": "ipa", + "audio": {"sample_rate": 22050}, + "inference": {}, + "phoneme_id_map": {"_": 0, "^": 1, "$": 2}, + } + cfg.update(extra) + return cfg + + def test_from_dict_parses_hop_length(self): + cfg = VoiceConfig.from_dict(self._phoonnx_cfg(hop_length=512)) + self.assertEqual(cfg.hop_length, 512) + + def test_from_dict_hop_length_default(self): + cfg = VoiceConfig.from_dict(self._phoonnx_cfg()) + self.assertEqual(cfg.hop_length, DEFAULT_HOP_LENGTH) + + +# --------------------------------------------------------------------------- +# phoneme_ids_to_audio — return-value shape +# --------------------------------------------------------------------------- + +class TestPhonemeIdsToAudio(unittest.TestCase): + def test_without_alignments_returns_ndarray(self): + voice = _make_voice([_fake_audio()[np.newaxis, np.newaxis, :]]) + result = voice.phoneme_ids_to_audio([3, 4, 5], include_alignments=False) + self.assertIsInstance(result, np.ndarray) + + def test_with_alignments_single_output_returns_tuple_none(self): + voice = _make_voice([_fake_audio()[np.newaxis, np.newaxis, :]]) + audio, samples = voice.phoneme_ids_to_audio([1, 0, 3, 0, 2], include_alignments=True) + self.assertIsInstance(audio, np.ndarray) + self.assertIsNone(samples) + + def test_with_alignments_two_outputs_returns_tuple_array(self): + ids = [1, 0, 3, 0, 4, 0, 2] + audio_arr = _fake_audio()[np.newaxis, np.newaxis, :] + dur_arr = np.ones((1, len(ids)), dtype=np.float32) + voice = _make_voice([audio_arr, dur_arr]) + audio, samples = voice.phoneme_ids_to_audio(ids, include_alignments=True) + self.assertIsInstance(samples, np.ndarray) + self.assertEqual(samples.dtype, np.int64) + self.assertEqual(len(samples), len(ids)) + + def test_hop_length_scales_duration(self): + """Model durations (in frames) are multiplied by hop_length to get samples.""" + ids = [1, 0, 3, 0, 4, 0, 2] + raw = np.array([[2.0, 1.0, 3.0, 1.0, 4.0, 1.0, 2.0]], dtype=np.float32) + voice = _make_voice([_fake_audio()[np.newaxis, np.newaxis, :], raw]) + voice.config.hop_length = 100 + _, samples = voice.phoneme_ids_to_audio(ids, include_alignments=True) + expected = np.array([200, 100, 300, 100, 400, 100, 200], dtype=np.int64) + np.testing.assert_array_equal(samples, expected) + + +# --------------------------------------------------------------------------- +# synthesize — alignment reconstruction +# --------------------------------------------------------------------------- + +class TestAlignmentReconstruction(unittest.TestCase): + """Tests for the id→phoneme walk in synthesize(include_alignments=True). + + The new implementation uses tokenizer.vocabulary.idx2char to look up each + id and folds blank/bos/eos durations into adjacent real phonemes. + """ + + def _synth(self, token_ids, durations, phonemes=None, char2idx=None): + """Run synthesize and return the single AudioChunk.""" + if phonemes is None: + phonemes = ["h", "e", "l"] + audio = _fake_audio() + dur_arr = np.array([durations], dtype=np.float32) + voice = _make_voice( + [audio[np.newaxis, np.newaxis, :], dur_arr], + phonemize_result=[phonemes], + token_ids=token_ids, + char2idx=char2idx, + ) + voice.config.hop_length = 1 # frame == sample for easy arithmetic + chunks = list(voice.synthesize("x", include_alignments=True)) + self.assertEqual(len(chunks), 1) + return chunks[0] + + def test_real_phonemes_extracted(self): + # ids: [bos=1, blank=0, h=3, blank=0, e=4, blank=0, l=5, blank=0, eos=2] + ids = [1, 0, 3, 0, 4, 0, 5, 0, 2] + durs = [10, 5, 20, 5, 30, 5, 40, 5, 10] + chunk = self._synth(ids, durs, phonemes=["h", "e", "l"]) + self.assertIsNotNone(chunk.phoneme_alignments) + names = [a.phoneme for a in chunk.phoneme_alignments] + self.assertEqual(names, ["h", "e", "l"]) + + def test_blank_before_phoneme_absorbed_forward(self): + # ids: bos(10), blank(5), h(20), trailing-blank(5), eos(10) + # bos + blank_before_h → pending=15; h appended with 15+20=35 + # trailing blank → pending=5; eos → absorbed into h: 35+10=45 + # after loop: trailing pending 5 → absorbed into h: 45+5=50 + ids = [1, 0, 3, 0, 2] + durs = [10, 5, 20, 5, 10] + chunk = self._synth(ids, durs, phonemes=["h"]) + self.assertEqual(chunk.phoneme_alignments[0].num_samples, 50) + + def test_eos_absorbed_into_last_phoneme(self): + # ids: bos, blank, h, blank, eos(10) + ids = [1, 0, 3, 0, 2] + durs = [0, 0, 20, 5, 10] + chunk = self._synth(ids, durs, phonemes=["h"]) + # h = 0+0 (bos+blank_bos) + 20 + 5 (trailing blank) + 10 (eos) = 35 + # bos=0, blank=0 → pending=0; h=20 → PhonemeAlignment("h",0+20)=20; blank=5 → pending=5; eos=10 → absorbed into last → 20+10=30; then trailing pending 5 absorbed → 35 + # Actually: bos=0 pending=0, blank_bos=0 pending=0, h=20 → align(h,20), blank=5 → pending=5, eos=10 → absorbed into last: h.num_samples=20+10=30, then pending=5 absorbed into last: 30+5=35 + self.assertEqual(chunk.phoneme_alignments[0].num_samples, 35) + + def test_total_samples_equal_sum_of_durations(self): + ids = [1, 0, 3, 0, 4, 0, 5, 0, 2] + durs = [10, 5, 20, 5, 30, 5, 40, 5, 10] + chunk = self._synth(ids, durs, phonemes=["h", "e", "l"]) + total_aligned = sum(a.num_samples for a in chunk.phoneme_alignments) + self.assertEqual(total_aligned, sum(durs)) + + def test_length_mismatch_alignments_none(self): + """phoneme_id_samples length != phoneme_ids length → alignments stay None.""" + # 3 duration values but 9 ids + audio = _fake_audio() + ids = [1, 0, 3, 0, 4, 0, 5, 0, 2] + voice = _make_voice( + [audio[np.newaxis, np.newaxis, :], + np.array([[10.0, 20.0, 30.0]], dtype=np.float32)], + phonemize_result=[["h", "e", "l"]], + token_ids=ids, + ) + voice.config.hop_length = 1 + chunk = list(voice.synthesize("x", include_alignments=True))[0] + self.assertIsNone(chunk.phoneme_alignments) + + def test_unknown_id_becomes_question_mark(self): + """An id not in idx2char maps to '?' but doesn't crash or fail.""" + char2idx = {"_": 0, "^": 1, "$": 2, "h": 3} + ids = [1, 0, 3, 0, 99, 0, 2] # id 99 unknown + durs = [5, 5, 20, 5, 20, 5, 5] + chunk = self._synth(ids, durs, phonemes=["h", "?"], char2idx=char2idx) + # reconstruction should succeed; unknown id maps to "?" + self.assertIsNotNone(chunk.phoneme_alignments) + names = [a.phoneme for a in chunk.phoneme_alignments] + self.assertIn("?", names) + + def test_no_alignments_requested_fields_none(self): + ids = [1, 0, 3, 0, 4, 0, 2] + audio = _fake_audio() + dur = np.ones((1, len(ids)), dtype=np.float32) + voice = _make_voice([audio[np.newaxis, np.newaxis, :], dur], + phonemize_result=[["h", "e"]], token_ids=ids) + chunk = list(voice.synthesize("hi", include_alignments=False))[0] + self.assertIsNone(chunk.phoneme_id_samples) + self.assertIsNone(chunk.phoneme_alignments) + + def test_phonemes_and_ids_always_populated(self): + ids = [1, 0, 3, 0, 4, 0, 2] + audio = _fake_audio() + voice = _make_voice([audio[np.newaxis, np.newaxis, :]], + phonemize_result=[["h", "e"]], token_ids=ids) + chunk = list(voice.synthesize("hi", include_alignments=False))[0] + self.assertEqual(chunk.phonemes, ["h", "e"]) + self.assertEqual(chunk.phoneme_ids, ids) + + def test_empty_phonemes_yields_no_chunk(self): + audio = _fake_audio() + voice = _make_voice([audio[np.newaxis, np.newaxis, :]], + phonemize_result=[[]], token_ids=[]) + chunks = list(voice.synthesize("", include_alignments=True)) + self.assertEqual(chunks, []) + + def test_model_without_alignment_output_gives_none(self): + ids = [1, 0, 3, 0, 4, 0, 2] + audio = _fake_audio() + voice = _make_voice([audio[np.newaxis, np.newaxis, :]], + phonemize_result=[["h", "e"]], token_ids=ids) + chunk = list(voice.synthesize("hi", include_alignments=True))[0] + self.assertIsNone(chunk.phoneme_id_samples) + self.assertIsNone(chunk.phoneme_alignments) + + +# --------------------------------------------------------------------------- +# Integration tests — real model, real synthesis +# --------------------------------------------------------------------------- + +VOICE_ID = "OpenVoiceOS/phoonnx_eu-ES_dii_espeak" +ALIGNED_MODEL = "/tmp/phoonnx_test_aligned_eu-ES.onnx" +VOICE_CONFIG = None # set in setUpClass + + +def _ensure_model(): + """Download voice + patch alignment output; return (onnx_path, config_path).""" + from phoonnx.model_manager import TTSModelManager + m = TTSModelManager() + m.merge_default_voices() + v = m.voices.get(VOICE_ID) + if v is None: + raise unittest.SkipTest(f"Voice {VOICE_ID} not in registry") + v.download_config() + v.download_model() + import glob + voice_path = v.voice_path + onnx_files = glob.glob(os.path.join(voice_path, "*.onnx")) + json_files = glob.glob(os.path.join(voice_path, "*.json")) + if not onnx_files or not json_files: + raise unittest.SkipTest("Voice files not downloaded") + return onnx_files[0], json_files[0] + + +def _patch_alignment(src_onnx, dst_onnx): + """Add the Ceil alignment output to an ONNX model.""" + import onnx + model = onnx.load(src_onnx) + ceil_names = {o for node in model.graph.node + if node.op_type == "Ceil" for o in node.output} + if not ceil_names: + raise unittest.SkipTest("No Ceil node found in model — alignment not patchable") + if len(ceil_names) > 1: + raise unittest.SkipTest("Multiple Ceil nodes — cannot autodetect") + # Check not already an output + existing = {o.name for o in model.graph.output} + name = next(iter(ceil_names)) + if name not in existing: + vinfo = onnx.helper.ValueInfoProto() + vinfo.name = name + model.graph.output.append(vinfo) + onnx.save(model, dst_onnx) + else: + shutil.copy(src_onnx, dst_onnx) + + +class TestAlignmentIntegration(unittest.TestCase): + """End-to-end tests with a real ONNX model. + + Skipped automatically when onnx is not installed or the voice cannot be + downloaded (no network, CI without model cache, etc.). + """ + + @classmethod + def setUpClass(cls): + try: + import onnx # noqa: F401 + except ImportError: + raise unittest.SkipTest("onnx not installed") + + try: + src_onnx, config_path = _ensure_model() + except unittest.SkipTest: + raise + except Exception as e: + raise unittest.SkipTest(f"Could not set up model: {e}") + + _patch_alignment(src_onnx, ALIGNED_MODEL) + cls.voice = TTSVoice.load(ALIGNED_MODEL, config_path) + cls.voice_no_align = TTSVoice.load(src_onnx, config_path) + + def test_synthesize_produces_audio(self): + chunks = list(self.voice.synthesize("kaixo")) + self.assertGreater(len(chunks), 0) + self.assertGreater(len(chunks[0].audio_float_array), 0) + + def test_without_alignments_fields_are_none(self): + chunk = list(self.voice.synthesize("kaixo", include_alignments=False))[0] + self.assertIsNone(chunk.phoneme_id_samples) + self.assertIsNone(chunk.phoneme_alignments) + + def test_phonemes_and_ids_always_present(self): + chunk = list(self.voice.synthesize("kaixo", include_alignments=False))[0] + self.assertGreater(len(chunk.phonemes), 0) + self.assertGreater(len(chunk.phoneme_ids), 0) + + def test_with_alignments_populated(self): + chunk = list(self.voice.synthesize("kaixo", include_alignments=True))[0] + self.assertIsNotNone(chunk.phoneme_id_samples) + self.assertIsNotNone(chunk.phoneme_alignments) + self.assertGreater(len(chunk.phoneme_alignments), 0) + + def test_alignment_total_equals_audio_length(self): + """Sum of per-phoneme sample counts must equal the audio frame count.""" + chunk = list(self.voice.synthesize("kaixo", include_alignments=True))[0] + total = sum(a.num_samples for a in chunk.phoneme_alignments) + self.assertEqual(total, len(chunk.audio_float_array)) + + def test_alignment_covers_all_phonemes(self): + """Every real phoneme from phonemization appears in the alignment list.""" + chunk = list(self.voice.synthesize("kaixo mundua", include_alignments=True))[0] + aligned_phonemes = [a.phoneme for a in chunk.phoneme_alignments] + for p in chunk.phonemes: + self.assertIn(p, aligned_phonemes) + + def test_all_durations_positive(self): + chunk = list(self.voice.synthesize("kaixo mundua", include_alignments=True))[0] + for a in chunk.phoneme_alignments: + self.assertGreater(a.num_samples, 0, + f"phoneme {a.phoneme!r} has zero duration") + + def test_model_without_alignment_output_gives_none(self): + """Unpatched model (single output) gives phoneme_alignments=None.""" + chunk = list(self.voice_no_align.synthesize("kaixo", include_alignments=True))[0] + self.assertIsNone(chunk.phoneme_id_samples) + self.assertIsNone(chunk.phoneme_alignments) + + def test_multiple_sentences(self): + """Multi-sentence input: each chunk is independently aligned.""" + chunks = list(self.voice.synthesize( + "kaixo mundua. zer moduz zaude?", include_alignments=True + )) + for chunk in chunks: + if chunk.phoneme_alignments: + total = sum(a.num_samples for a in chunk.phoneme_alignments) + self.assertEqual(total, len(chunk.audio_float_array)) + + def test_alignment_phoneme_strings_are_real(self): + """Alignment phoneme strings are non-empty and not all '?'.""" + chunk = list(self.voice.synthesize("kaixo", include_alignments=True))[0] + phonemes = [a.phoneme for a in chunk.phoneme_alignments] + self.assertTrue(all(p for p in phonemes), "empty phoneme string in alignments") + unknown = [p for p in phonemes if p == "?"] + self.assertEqual(unknown, [], f"unknown ids in alignment: {phonemes}") + + +if __name__ == "__main__": + unittest.main()