[TTS][Magpietts] Unify Longform and Standard Inference logic #15375
subhankar-ghosh wants to merge 17 commits into main from magpietts_longform_unify
Conversation
Signed-off-by: subhankar-ghosh <subhankar2321@gmail.com>
Signed-off-by: subhankar-ghosh <subhankar-ghosh@users.noreply.github.com>
Pull request overview
Refactors MagpieTTS inference to use a single “chunked” inference path for both short and long texts, with dataset-driven automatic sentence chunking based on per-sample language thresholds.
Changes:
- Introduces language-aware thresholding plus a unified `chunk_text_for_inference()` chunking utility, replacing the prior longform detection logic (sketched below).
- Replaces `LongFormTTSInferenceDataset` with `ChunkedTTSInferenceDataset` and updates the inference runner to always use the unified multi/single-chunk loop.
- Updates the CLI/example script to remove explicit longform args and align with the unified inference flow.
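To make the thresholding behavior concrete, here is a minimal sketch of what a language-aware chunking helper can look like. The threshold values, the `_sketch` suffix, and the regex sentence split are assumptions for illustration; the real `chunk_text_for_inference()` in `tts_dataset_utils.py` may differ.

```python
import re

# Hypothetical per-language word-count thresholds; the PR's actual values
# live in tts_dataset_utils.py and may differ.
CHUNK_WORD_THRESHOLDS = {"en": 60, "de": 50}
DEFAULT_CHUNK_WORD_THRESHOLD = 40


def chunk_text_for_inference_sketch(text: str, language: str = "en") -> list[str]:
    """Return the text as one chunk if short, else as sentence-based chunks."""
    threshold = CHUNK_WORD_THRESHOLDS.get(language, DEFAULT_CHUNK_WORD_THRESHOLD)
    if len(text.split()) <= threshold:
        return [text]  # short input: single chunk, same path as standard inference
    # Naive regex sentence split; the real utility is language-aware.
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    chunks: list[str] = []
    current: list[str] = []
    for sentence in sentences:
        current.append(sentence)
        if len(" ".join(current).split()) >= threshold:
            chunks.append(" ".join(current))
            current = []
    if current:
        chunks.append(" ".join(current))
    return chunks
```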
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| `nemo/collections/tts/parts/utils/tts_dataset_utils.py` | Adds language-aware sentence splitting, thresholds, tokenizer mapping, and a unified chunking helper. |
| `nemo/collections/tts/data/text_to_speech_dataset.py` | Replaces the longform inference dataset with a unified chunked inference dataset plus mixed-chunk collation. |
| `nemo/collections/tts/modules/magpietts_inference/inference.py` | Removes standard/longform branching; always runs the unified chunk loop via `generate_speech()`. |
| `nemo/collections/tts/models/magpietts.py` | Renames longform state/config to chunked equivalents; updates `do_tts()` to the unified chunked generation path. |
| `examples/tts/magpietts_inference.py` | Removes longform CLI controls and updates messaging for unified chunking behavior. |
| `tests/collections/tts/parts/utils/test_tts_dataset_utils.py` | Adds unit tests for the new thresholds and unified chunking helper. |
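The `generate_speech()` row above is the crux of the refactor. The following is a toy sketch of the unified control flow only; the method names `create_chunk_state` and `generate_chunk` are illustrative, not the real internals.

```python
import numpy as np


class ToyTTSModel:
    """Toy stand-in for the real model; method names are illustrative."""

    def create_chunk_state(self, batch_size: int) -> dict:
        # State carried across chunks (e.g., decoder context for continuity).
        return {"batch_size": batch_size, "prev_audio": None}

    def generate_chunk(self, text_chunk: str, state: dict) -> np.ndarray:
        audio = np.zeros(100 * len(text_chunk.split()), dtype=np.float32)
        state["prev_audio"] = audio
        return audio


def generate_speech_sketch(model: ToyTTSModel, chunks: list[str]) -> np.ndarray:
    # Short texts arrive as a single chunk, long texts as many; the same loop
    # runs either way, so there is no separate "longform" branch.
    state = model.create_chunk_state(batch_size=1)
    return np.concatenate([model.generate_chunk(c, state) for c in chunks])


audio = generate_speech_sketch(ToyTTSModel(), ["Hello world.", "A second sentence."])
print(audio.shape)  # (500,) for the 5 words above
```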
Signed-off-by: subhankar-ghosh <subhankar2321@gmail.com>
The GitHub UI still says that there are conflicts.
Signed-off-by: subhankar-ghosh <subhankar2321@gmail.com>
Signed-off-by: subhankar-ghosh <subhankar-ghosh@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Signed-off-by: Subhankar Ghosh <subhankar2321@gmail.com>
```python
import wandb
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from lhotse.serialization import load_yaml
```
Check notice
Code scanning / CodeQL: Unused import (Note)
Copilot Autofix (AI, 4 days ago)
To fix an unused import, we simply remove the import statement (or the specific symbol) that is not used anywhere in the file. This reduces unnecessary dependencies and cleans up the code without affecting runtime behavior.
In this case, the best fix is to delete the line `from lhotse.serialization import load_yaml` at line 28 of `nemo/collections/tts/models/magpietts.py`. No other code changes are necessary, since we are not altering any used functionality and there are no visible references to `load_yaml`. Ensure that only this line is removed and that all other imports remain untouched.
```diff
@@ -25,7 +25,6 @@
 import torch
 import wandb
 from hydra.utils import instantiate
-from lhotse.serialization import load_yaml
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from omegaconf import DictConfig, OmegaConf, open_dict
```
Signed-off-by: subhankar-ghosh <subhankar2321@gmail.com>
…o into magpietts_longform_unify
Signed-off-by: Subhankar Ghosh <subhankarg@nvidia.com>
```python
# First try sample's tokenizer_names (from dataset config)
if data.tokenizer_names is not None:
    return data.tokenizer_names[0]  # Use first (deterministic for inference)
```
Where does this `tokenizer_names` parameter come from? It seems to assume that you read it from the dataset JSON, which doesn't seem ideal.
And do we only support this in the non-Lhotse path? What happens if we try a Lhotse dataset?
We do inference on the non-Lhotse path only.
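For context, the excerpt above shows only the first branch of the tokenizer selection. Below is a hypothetical sketch of a fuller fallback chain; only the `tokenizer_names` branch is confirmed by this PR, the rest is illustrative.

```python
def select_tokenizer_name_sketch(data, default: str = "en") -> str:
    # Confirmed by the excerpt above: prefer the sample's tokenizer_names,
    # taking the first entry so inference stays deterministic.
    if getattr(data, "tokenizer_names", None):
        return data.tokenizer_names[0]
    # Hypothetical fallbacks: a per-sample language field, then a default.
    if getattr(data, "language", None):
        return data.language
    return default
```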
[🤖]: Hi @subhankar-ghosh 👋, We wanted to let you know that a CICD pipeline for this PR just finished successfully. So it might be time to merge this PR or get some approvals.
Signed-off-by: subhankar-ghosh <subhankar2321@gmail.com>
…o into magpietts_longform_unify
```python
chunk_state = self.model.create_longform_chunk_state(batch_size=batch_size)
# Clear stale KV cache from prior inference calls (e.g., the previous batch or
# dataset may have left it with populated tensors).
print(f"Resetting KV cache for decoder: {self.model.use_kv_cache_for_inference}")
```
@XuesongYang Let me know if this piece of code looks good to you.
LGTM!
@rfejgin We never reset the cache for the local transformer in either standard or longform inference before. Do we always disable the KV cache, or does it depend on how many frames are stacked?
Thanks for being careful about this, @XuesongYang and @subhankar-ghosh. Still, I think no special handling is needed for the local transformer: the LT already resets its KV cache automatically every timestep (or frame stack), since separate timesteps are completely independent for the LT. The reset happens at the start of `local_transformer_sample_autoregressive()` and `local_transformer_sample_maskgit()`.
> Do we always disable the KV cache, or does it depend on how many frames are stacked?

It depends on the LT type. For MaskGIT, we keep the KV cache off because that type of LT is non-causal, which makes standard KV caching impossible. For the autoregressive LT, the KV cache is on and we reset it on every timestep (i.e. once per frame or stack).
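To make the per-timestep reset discipline concrete, here is a toy sketch. `ToyLocalTransformer`, `reset_kv_cache`, and `predict_next` are illustrative names, not the actual NeMo implementation.

```python
import torch


class ToyLocalTransformer:
    """Toy stand-in for the local transformer (LT); all names are illustrative."""

    def __init__(self):
        self.kv_cache = None  # populated only during within-timestep sampling

    def reset_kv_cache(self):
        # Separate timesteps are independent for the LT, so the cache is
        # cleared at the start of every timestep's sampling call.
        self.kv_cache = None

    def predict_next(self, prefix):
        # Dummy next-code prediction; a real LT would attend over `prefix`
        # while extending self.kv_cache.
        self.kv_cache = list(prefix)
        return torch.randint(0, 1024, (1,))


def sample_one_timestep(lt: ToyLocalTransformer, num_codebooks: int = 8) -> torch.Tensor:
    lt.reset_kv_cache()  # per-timestep reset, as described above
    codes: list[torch.Tensor] = []
    for _ in range(num_codebooks):
        # Autoregressive over codebooks *within* the current timestep; the KV
        # cache is only valid across these inner steps.
        codes.append(lt.predict_next(codes))
    return torch.cat(codes)


print(sample_one_timestep(ToyLocalTransformer()))
```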
Important
The **Update branch** button must only be pressed on very rare occasions. An outdated branch never blocks the merge of a PR.
Please reach out to the automation team before pressing that button.
What does this PR do?
This pull request refactors and unifies the text chunking and inference logic for TTS (Text-to-Speech) in the MagpieTTS pipeline. The main change is the replacement of the previous "longform" inference logic with a new, language-aware, unified chunked inference path. This affects dataset preparation, model state management, argument parsing, and the inference runner, making the codebase simpler and more robust for both short and long texts.
Key changes:

**Unified Inference and Text Chunking**

- Unifies short- and long-text inference into a single chunked path. (`examples/tts/magpietts_inference.py`, `nemo/collections/tts/data/text_to_speech_dataset.py`, `nemo/collections/tts/models/magpietts.py`) [1] [2] [3]
- Removes the explicit longform CLI arguments (`--longform_mode`, `--longform_word_threshold`, etc.), simplifying the inference interface. (`examples/tts/magpietts_inference.py`) [1] [2] [3]

**Dataset and Collation Refactor**

- Introduces `ChunkedTTSInferenceDataset` (replacing `LongFormTTSInferenceDataset`) with per-sample, language-aware chunking and tokenizer selection. The dataset now automatically decides the chunking strategy based on language and text length. (`nemo/collections/tts/data/text_to_speech_dataset.py`) [1] [2] [3]
- Updates `collate_fn` to handle variable-length chunked batches, padding as needed, and to generalize beyond the previous longform-specific logic (see the padding sketch after this list). (`nemo/collections/tts/data/text_to_speech_dataset.py`) [1] [2]

**Model and State Naming Consistency**

- Renames longform state/config (`LongformDecoderState` → `ChunkedDecoderState`, `LongformConfig` → `ChunkedInferenceConfig`) throughout the model code for clarity and consistency with the new unified approach. (`nemo/collections/tts/models/magpietts.py`) [1] [2] [3] [4]
- Removes the `_needs_longform_inference` method and all language-threshold logic from the model, as chunking is now handled in a unified, language-aware way. (`nemo/collections/tts/models/magpietts.py`)

**Utility and Import Updates**

- Cleans up related utilities and imports. (`nemo/collections/tts/models/magpietts.py`, `nemo/collections/tts/data/text_to_speech_dataset.py`) [1] [2]

These changes make the TTS inference pipeline easier to use and maintain, while improving support for multilingual and variable-length text inputs.
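As a rough illustration of the mixed-chunk collation mentioned above, padding variable-length chunk token sequences into a single batch can look like the following. This is a sketch only; the real `collate_fn` in `text_to_speech_dataset.py` handles more fields, and `collate_chunks_sketch` is not its actual signature.

```python
import torch


def collate_chunks_sketch(token_lists: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]:
    # Pad each chunk's token sequence to the longest one in the batch.
    lengths = torch.tensor([len(tokens) for tokens in token_lists])
    padded = torch.zeros(len(token_lists), int(lengths.max()), dtype=torch.long)
    for i, tokens in enumerate(token_lists):
        padded[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
    return padded, lengths


batch, lengths = collate_chunks_sketch([[1, 2, 3], [4, 5], [6]])
print(batch.shape, lengths.tolist())  # torch.Size([3, 3]) [3, 2, 1]
```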
Collection: TTS
Changelog
Usage
```python
# Add a code snippet demonstrating how to use this
```

GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI, remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The contributor guidelines contain specific people who can review PRs to various areas.
Additional Information