[TTS][Magpietts] Unify Longform and Standard Inference logic #15375

Open

subhankar-ghosh wants to merge 17 commits into main from magpietts_longform_unify

Conversation

@subhankar-ghosh
Collaborator

Important

The Update branch button should only be pressed on very rare occasions.
An outdated branch is never blocking the merge of a PR.
Please reach out to the automation team before pressing that button.

What does this PR do?

This pull request refactors and unifies the text chunking and inference logic for TTS (Text-to-Speech) in the MagpieTTS pipeline. The main change is the replacement of the previous "longform" inference logic with a new, language-aware, unified chunked inference path. This affects dataset preparation, model state management, argument parsing, and the inference runner, making the codebase simpler and more robust for both short and long texts.

Key changes:

Unified Inference and Text Chunking

  • Replaced the old longform inference logic with a unified, automatic text chunking approach that decides chunking per sample based on language-specific thresholds: short texts are processed as single chunks, while long texts are split into sentences automatically (a minimal sketch follows this list). (examples/tts/magpietts_inference.py, nemo/collections/tts/data/text_to_speech_dataset.py, nemo/collections/tts/models/magpietts.py) [1] [2] [3]
  • Removed all command-line arguments related to explicit longform control (--longform_mode, --longform_word_threshold, etc.), simplifying the inference interface. (examples/tts/magpietts_inference.py) [1] [2] [3]
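
For illustration, a minimal sketch of this per-sample decision (the function name, threshold values, and sentence-splitting rule below are assumptions, not the actual helpers added in tts_dataset_utils.py):

import re

# Assumed per-language word-count thresholds above which a text gets split;
# the real helpers and values live in tts_dataset_utils.py.
CHUNK_WORD_THRESHOLDS = {"en": 60, "de": 50, "zh": 40}
DEFAULT_CHUNK_WORD_THRESHOLD = 50


def chunk_text(text: str, language: str) -> list[str]:
    """Return the whole text as one chunk if it is short, else split it into sentences."""
    threshold = CHUNK_WORD_THRESHOLDS.get(language, DEFAULT_CHUNK_WORD_THRESHOLD)
    if len(text.split()) <= threshold:
        return [text]  # short text: a single chunk, same as standard inference
    # Long text: naive sentence split on terminal punctuation.
    sentences = re.split(r"(?<=[.!?。])\s+", text)
    return [s.strip() for s in sentences if s.strip()]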

Dataset and Collation Refactor

  • Introduced ChunkedTTSInferenceDataset (replacing LongFormTTSInferenceDataset) with per-sample, language-aware chunking and tokenizer selection. The dataset now automatically decides chunking strategy based on language and text length. (nemo/collections/tts/data/text_to_speech_dataset.py) [1] [2] [3]
  • Updated the dataset's collate_fn to handle variable-length chunked batches, padding as needed, and generalized it beyond the previous longform-specific logic (a simplified collation sketch follows below). (nemo/collections/tts/data/text_to_speech_dataset.py) [1] [2]
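
A simplified sketch of the mixed-chunk collation idea (the function name, tensor layout, and return values are assumptions, not the PR's actual collate_fn):

import torch


def collate_chunked_batch(token_chunks: list[list[torch.Tensor]]):
    """token_chunks[i] holds the 1-D integer token tensors for all chunks of sample i."""
    flat = [chunk for sample in token_chunks for chunk in sample]
    lengths = torch.tensor([len(c) for c in flat], dtype=torch.long)
    padded = torch.zeros(len(flat), int(lengths.max()), dtype=torch.long)
    for i, chunk in enumerate(flat):
        padded[i, : len(chunk)] = chunk
    # Track how many chunks each sample contributed so generated audio can be
    # re-assembled per sample after inference.
    chunks_per_sample = torch.tensor([len(s) for s in token_chunks], dtype=torch.long)
    return padded, lengths, chunks_per_sample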

Model and State Naming Consistency

  • Renamed all "longform" classes and configs to "chunked" (e.g., LongformDecoderState → ChunkedDecoderState, LongformConfig → ChunkedInferenceConfig) throughout the model code for clarity and consistency with the new unified approach. (nemo/collections/tts/models/magpietts.py) [1] [2] [3] [4]
  • Removed the _needs_longform_inference method and all language threshold logic from the model, as chunking is now handled in a unified, language-aware way. (nemo/collections/tts/models/magpietts.py)

Utility and Import Updates

  • Added and updated utility imports for chunked inference and tokenizer selection to support the new pipeline. (nemo/collections/tts/models/magpietts.py, nemo/collections/tts/data/text_to_speech_dataset.py) [1] [2]

These changes make the TTS inference pipeline easier to use and maintain, while improving support for multilingual and variable-length text inputs.

Collection: TTS

Changelog

  • Add specific line-by-line info of high-level changes in this PR.

Usage

  • Example usage (a hedged sketch, see below):
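
The snippet below is a hedged sketch of the unified flow; anything not confirmed by this PR (class name, import path, checkpoint, and the do_tts() signature) is an assumption. See examples/tts/magpietts_inference.py for the actual entry point.

# Hedged sketch only: the class name, import path, checkpoint name, and do_tts()
# signature are assumptions; see examples/tts/magpietts_inference.py for real usage.
from nemo.collections.tts.models import MagpieTTSModel  # assumed import path

model = MagpieTTSModel.restore_from("magpietts.nemo")  # hypothetical checkpoint
model.eval()

texts = [
    "Hello world.",                          # short text: handled as a single chunk
    "A long multi-sentence passage. " * 30,  # long text: split into sentence chunks
]
# Same call for both; chunking is decided per sample and per language inside the dataset.
audio = model.do_tts(texts)  # assumed signature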

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs in various areas.

Additional Information

  • Related to # (issue)

Copilot AI (Contributor) left a comment

Pull request overview

Refactors MagpieTTS inference to use a single “chunked” inference path for both short and long texts, with dataset-driven automatic sentence chunking based on per-sample language thresholds.

Changes:

  • Introduces language-aware thresholding + unified chunk_text_for_inference() chunking utility (replacing prior longform detection logic).
  • Replaces LongFormTTSInferenceDataset with ChunkedTTSInferenceDataset and updates the inference runner to always use the unified multi/single-chunk loop.
  • Updates CLI/example script to remove explicit longform args and align with the unified inference flow.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 8 comments.

Summary per file:

  • nemo/collections/tts/parts/utils/tts_dataset_utils.py: Adds language-aware sentence splitting, thresholds, tokenizer mapping, and a unified chunking helper.
  • nemo/collections/tts/data/text_to_speech_dataset.py: Replaces the longform inference dataset with a unified chunked inference dataset plus mixed-chunk collation.
  • nemo/collections/tts/modules/magpietts_inference/inference.py: Removes standard/longform branching; always runs the unified chunk loop via generate_speech().
  • nemo/collections/tts/models/magpietts.py: Renames longform state/config to chunked equivalents; updates do_tts() to the unified chunked generation path.
  • examples/tts/magpietts_inference.py: Removes longform CLI controls and updates messaging for the unified chunking behavior.
  • tests/collections/tts/parts/utils/test_tts_dataset_utils.py: Adds unit tests for the new thresholds and the unified chunking helper.


@blisc
Collaborator

blisc commented Feb 10, 2026

  • Can you resolve conflicts?
  • Can you add a unit test for magpietts.generate_speech that contains a batch of short and long texts?

@blisc
Collaborator

blisc commented Feb 12, 2026

The GitHub UI still says that there are conflicts.

subhankar-ghosh and others added 2 commits February 12, 2026 09:58
import wandb
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from lhotse.serialization import load_yaml

Check notice (Code scanning / CodeQL): Unused import

Import of 'load_yaml' is not used.

Copilot Autofix

To fix an unused import, we simply remove the import statement (or the specific symbol) that is not used anywhere in the file. This reduces unnecessary dependencies and cleans up the code without affecting runtime behavior.

In this case, the best fix is to delete the import line from lhotse.serialization import load_yaml at line 28 in nemo/collections/tts/models/magpietts.py. No other code changes are necessary, since we are not altering any used functionality and there are no visible references to load_yaml. Ensure that only this line is removed and that all other imports remain untouched.

Suggested changeset 1
nemo/collections/tts/models/magpietts.py

Autofix patch

Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py
--- a/nemo/collections/tts/models/magpietts.py
+++ b/nemo/collections/tts/models/magpietts.py
@@ -25,7 +25,6 @@
 import torch
 import wandb
 from hydra.utils import instantiate
-from lhotse.serialization import load_yaml
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 
 from omegaconf import DictConfig, OmegaConf, open_dict
EOF
subhankar-ghosh and others added 2 commits February 17, 2026 02:57
Comment on lines +861 to +863
# First try sample's tokenizer_names (from dataset config)
if data.tokenizer_names is not None:
    return data.tokenizer_names[0]  # Use first (deterministic for inference)
Collaborator

Where does this tokenizer_names parameter come from? It seems to assume that it is read from the dataset JSON, which doesn't seem ideal.

Collaborator

And do we only support this in the non-Lhotse path? What happens if we try a Lhotse dataset?

Collaborator Author

We do inference on the non-Lhotse path only.
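
For readers following the thread, a rough sketch of the selection order under discussion; only the tokenizer_names lookup mirrors the snippet above, the rest is assumed.

# Rough sketch of the selection order discussed above; the helper name and the
# language fallback are illustrative assumptions, not the actual model code.
def select_tokenizer_name(data, language_to_tokenizer: dict, default: str) -> str:
    # 1) Prefer the sample's own tokenizer_names (from the non-Lhotse dataset config).
    if getattr(data, "tokenizer_names", None):
        return data.tokenizer_names[0]  # use the first entry, deterministic for inference
    # 2) Otherwise fall back to a per-language default.
    language = getattr(data, "language", None)
    return language_to_tokenizer.get(language, default)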

@github-actions
Contributor

[🤖]: Hi @subhankar-ghosh 👋,

We wanted to let you know that a CICD pipeline for this PR just finished successfully.

So it might be time to merge this PR or get some approvals.

//cc @chtruong814 @ko3n1g @pablo-garay @thomasdhc

chunk_state = self.model.create_longform_chunk_state(batch_size=batch_size)
# Clear stale KV cache from prior inference calls (e.g., the previous batch or dataset
# may have left it populated).
print(f"Resetting KV cache for decoder: {self.model.use_kv_cache_for_inference}")
Collaborator Author

@XuesongYang Let me know if this piece of code looks good to you.

@XuesongYang (Collaborator) Feb 20, 2026

LGTM!

@rfejgin We never reset the cache for the local transformer before, for either standard or longform inference. Do we always disable the KV cache, or does it depend on how many frames are stacked?

@rfejgin (Collaborator) Feb 20, 2026

Thanks for being careful about this @XuesongYang and @subhankar-ghosh. Still, I think no special handling is needed for the local transformer: the LT already resets its KV cache automatically every timestep (or frame stack), since separate timesteps are completely independent for the LT. The reset happens at the start of local_transformer_sample_autoregressive() and local_transformer_sample_maskgit().

Do we always disable the KV cache, or does it depend on how many frames are stacked?

It depends on the LT type. For MaskGit, we keep the KV cache off because that type of LT is non-causal, which makes standard KV caching impossible. For the autoregressive LT, the KV cache is on and we reset it on every timestep (i.e., once per frame or frame stack).
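
A minimal sketch of the per-timestep pattern described here; everything except the sampler behavior rfejgin describes is an assumption.

# Minimal sketch of the per-timestep reset pattern described above; the cache and
# sampling method names are illustrative assumptions, not the module's real API.
def local_transformer_sample_autoregressive_sketch(local_transformer, frame_context):
    local_transformer.reset_kv_cache()  # fresh cache at the start of every timestep / frame stack
    tokens = []
    for _ in range(local_transformer.num_codebooks):
        # Codebooks within one timestep are decoded autoregressively and reuse the
        # cache; nothing carries over to the next timestep, so no cross-step reset
        # is needed elsewhere.
        tokens.append(local_transformer.sample_next(frame_context, tokens))
    return tokens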
