52 changes: 10 additions & 42 deletions examples/tts/magpietts_inference.py
@@ -157,9 +157,9 @@ def run_inference_and_evaluation(
) -> Tuple[Optional[float], Optional[float]]:
"""Run inference and optional evaluation on specified datasets.

Longform inference is automatically detected based on text characteristics
when longform_mode="auto" (default). Use longform_mode="always" or "never"
for explicit control.
Uses unified inference path with automatic text chunking based on
per-sample language thresholds. Short texts are processed as single chunks,
long texts are automatically split into sentences.

Args:
model_config: Configuration for loading the model.
@@ -202,8 +202,8 @@ def run_inference_and_evaluation(
f"{checkpoint_name}_{moe_info}{inference_config.build_identifier()}_SV_{eval_config.sv_model}"
)

# Create inference runner (auto-detects longform based on config.longform_mode)
logging.info(f"Longform mode: {inference_config.longform_mode}")
# Create inference runner (uses unified path with automatic text chunking)
logging.info("Using unified inference with automatic text chunking based on language thresholds")
runner = MagpieInferenceRunner(model, inference_config)

# Tracking metrics across datasets
@@ -462,25 +462,6 @@ def create_argument_parser() -> argparse.ArgumentParser:
infer_group.add_argument(f"--{field.name}", **extra_args)
infer_group.add_argument('--batch_size', type=int, default=32)
infer_group.add_argument('--use_cfg', action='store_true', help='Enable classifier-free guidance')
infer_group.add_argument(
'--longform_mode',
type=str,
default='auto',
choices=['auto', 'always', 'never'],
help='Longform inference mode: auto (detect from text), always, or never',
)
infer_group.add_argument(
'--longform_word_threshold',
type=int,
default=40,
help='Word threshold for auto-detection of longform text',
)
infer_group.add_argument(
'--longform_max_decoder_steps',
type=int,
default=50000,
help='Maximum decoder steps for longform inference',
)

# Local transformer / MaskGit arguments
infer_group.add_argument('--use_local_transformer', action='store_true')
@@ -548,26 +529,15 @@ def main(argv=None):
parser.error("You must provide either:\n 1. --hparams_files and --checkpoint_files\n 2. --nemo_files")

# Build configurations
# Use higher max_decoder_steps for longform inference when mode is 'always'
if args.longform_mode == 'always':
max_decoder_steps = args.longform_max_decoder_steps
elif args.longform_mode == 'auto':
# Use longform steps if any text appears long (will be checked in runner)
max_decoder_steps = args.longform_max_decoder_steps
else: # 'never'
max_decoder_steps = 440
model_inference_parameters = {}
for field in fields(ModelInferenceParameters):
field = field.name
if field == "max_decoder_steps":
model_inference_parameters[field] = max_decoder_steps
continue
arg_from_cmdline = vars(args)[field]
field_name = field.name
arg_from_cmdline = vars(args)[field_name]
if arg_from_cmdline is not None:
if field in ["estimate_alignment_from_layers", "apply_prior_to_layers"]:
model_inference_parameters[field] = parse_layer_list(vars(args)[field])
if field_name in ["estimate_alignment_from_layers", "apply_prior_to_layers"]:
model_inference_parameters[field_name] = parse_layer_list(arg_from_cmdline)
else:
model_inference_parameters[field] = vars(args)[field]
model_inference_parameters[field_name] = arg_from_cmdline

inference_config = InferenceConfig(
model_inference_parameters=ModelInferenceParameters.from_dict(model_inference_parameters),
@@ -579,8 +549,6 @@
maskgit_noise_scale=args.maskgit_noise_scale,
maskgit_fixed_schedule=args.maskgit_fixed_schedule,
maskgit_sampling_type=args.maskgit_sampling_type,
longform_mode=args.longform_mode,
longform_word_threshold=args.longform_word_threshold,
)

eval_config = EvaluationConfig(
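Editor's note: for reference, here is a minimal, self-contained sketch of the simplified parameter-building loop that main() uses above. The dataclass fields and the argument values are hypothetical stand-ins; only the field names estimate_alignment_from_layers / apply_prior_to_layers and the parse_layer_list helper are taken from the script itself.

from dataclasses import dataclass, fields
from typing import List, Optional

@dataclass
class ModelInferenceParameters:
    # Hypothetical subset of fields, for illustration only.
    temperature: Optional[float] = None
    max_decoder_steps: Optional[int] = None
    estimate_alignment_from_layers: Optional[List[int]] = None

def parse_layer_list(value: str) -> List[int]:
    # Stand-in for the script's parse_layer_list helper.
    return [int(x) for x in value.split(",") if x.strip()]

# Pretend these came from argparse (vars(args)); None means "not set on the CLI".
args = {"temperature": 0.7, "max_decoder_steps": None, "estimate_alignment_from_layers": "3,4,5"}

model_inference_parameters = {}
for field in fields(ModelInferenceParameters):
    field_name = field.name
    arg_from_cmdline = args[field_name]
    if arg_from_cmdline is not None:
        if field_name in ["estimate_alignment_from_layers", "apply_prior_to_layers"]:
            model_inference_parameters[field_name] = parse_layer_list(arg_from_cmdline)
        else:
            model_inference_parameters[field_name] = arg_from_cmdline

print(model_inference_parameters)
# {'temperature': 0.7, 'estimate_alignment_from_layers': [3, 4, 5]}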
90 changes: 63 additions & 27 deletions nemo/collections/tts/data/text_to_speech_dataset.py
@@ -31,7 +31,9 @@
_read_audio,
beta_binomial_prior_distribution,
chunk_and_tokenize_text_by_sentence,
chunk_text_for_inference,
filter_dataset_by_duration,
get_tokenizer_for_language,
get_weighted_sampler,
load_audio,
stack_tensors,
@@ -785,23 +787,25 @@ def collate_fn(self, batch: List[dict]):
return {"chosen": chosen_collated, "rejected": rejected_collated}


class LongFormTTSInferenceDataset(MagpieTTSDataset):
class ChunkedTTSInferenceDataset(MagpieTTSDataset):
"""
Dataset for longform TTS inference with sentence-level text chunking.
Unified dataset for TTS inference with automatic text chunking.

Inherits from MagpieTTSDataset to reuse context audio loading, text conditioning,
and other preprocessing logic. Adds sentence-level text chunking on top.
and other preprocessing logic. Uses language-aware chunking to automatically
decide whether to split text into sentences:
- Short text (below language threshold): returns single chunk
- Long text (above language threshold): returns multiple sentence chunks (multi-chunk)

Both language (for threshold) and tokenizer are determined per-sample:
- Language from manifest's 'language' field
- Tokenizer from sample's tokenizer_names or mapped from language

Args:
dataset_meta: Dataset metadata dictionary (same format as MagpieTTSDataset).
sample_rate: Audio sample rate.
tokenizer_name: Name of the tokenizer to use for sentence chunking.
codec_model_samples_per_frame: Samples per codec frame.
eos_id: End-of-sequence token ID.
audio_bos_id: Audio BOS token ID (for target audio).
audio_eos_id: Audio EOS token ID (for target audio).
context_audio_bos_id: Context audio BOS token ID.
context_audio_eos_id: Context audio EOS token ID.
num_audio_codebooks: Number of audio codebooks.
context_duration_min: Minimum context duration in seconds.
context_duration_max: Maximum context duration in seconds.
@@ -815,7 +819,6 @@ def __init__(
self,
dataset_meta: Dict[str, Any],
sample_rate: int,
tokenizer_name: str,
codec_model_samples_per_frame: int,
eos_id: int,
num_audio_codebooks: int,
@@ -844,41 +847,71 @@ def __init__(
dataset_type='test',
**kwargs,
)
self.tokenizer_name = tokenizer_name

def _get_tokenizer_name(self, data, language: str) -> str:
"""Get tokenizer name for a sample, from tokenizer_names or language mapping.

Args:
data: DatasetSample with optional tokenizer_names field.
language: Language code from manifest.

Returns:
Tokenizer name to use for encoding.
"""
# First try sample's tokenizer_names (from dataset config)
if data.tokenizer_names is not None:
return data.tokenizer_names[0] # Use first (deterministic for inference)
Comment on lines +861 to +863

Collaborator: Where does this tokenizer_names parameter come from? It seems to assume that you read it from the dataset json, which doesn't seem ideal.

Collaborator: And do we only support this in the non-Lhotse path? What happens if we try a Lhotse dataset?

Author (Collaborator): We do inference on non-Lhotse path only.


# Fall back to centralized language-based mapping
if self.text_tokenizer is not None:
available = list(self.text_tokenizer.tokenizers.keys())
return get_tokenizer_for_language(language, available)

return "english_phoneme"

def __getitem__(self, idx: int) -> Dict[str, Any]:
"""
Add sentence chunking on top of parent's __getitem__.
Add automatic text chunking on top of parent's __getitem__.

Uses language-aware chunking that automatically decides whether to split:
- Short text (below threshold): returns as single chunk
- Long text (above threshold): returns as multiple sentence chunks

Both tokenizer and chunking threshold are determined per-sample based on
language and tokenizer configuration.

Returns:
Dictionary containing all parent fields plus:
- idx: Sample index
- chunked_tokens: List of tokenized text chunks (per sentence)
- chunked_tokens: List of tokenized text chunks (1 for short, N for long)
- chunked_tokens_len: List of token lengths
- entry: Original manifest entry
"""
# Get text for sentence chunking
# Get data sample for text and tokenizer info
data = self.data_samples[idx]
text = data.text # entry.get("normalized_text", entry.get("text", ""))
text = data.text

# Call parent to get ALL the context audio, text conditioning, etc.
example = super().__getitem__(idx)

# Sentence chunking (longform-specific)
chunked_tokens, chunked_tokens_len, _ = chunk_and_tokenize_text_by_sentence(
text,
self.tokenizer_name,
self.text_tokenizer,
self.eos_id,
language=example['language'],
# Get language and tokenizer per-sample
language = example["language"]
tokenizer_name = self._get_tokenizer_name(data, language)

# Unified chunking: automatically decides whether to split based on language threshold
chunked_tokens, chunked_tokens_len, _ = chunk_text_for_inference(
text=text,
language=language,
tokenizer_name=tokenizer_name,
text_tokenizer=self.text_tokenizer,
eos_token_id=self.eos_id,
)

# Handle empty text edge case
if not chunked_tokens:
chunked_tokens = [torch.tensor([self.eos_id], dtype=torch.int32)]
chunked_tokens_len = [1]

# Add longform-specific fields
# Add chunking-related fields
example['idx'] = idx
example['chunked_tokens'] = chunked_tokens
example['chunked_tokens_len'] = chunked_tokens_len
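Editor's note: to illustrate the short-vs-long decision described in the docstring above, here is a standalone mock of the chunking behaviour. The real chunk_text_for_inference lives in NeMo's TTS utilities and its exact signature, thresholds, and sentence splitting differ; the word thresholds, the regex, and the toy tokenizer below are assumptions for illustration only.

import re
from typing import Callable, List, Tuple
import torch

# Hypothetical per-language word thresholds; the real values live inside NeMo.
LANGUAGE_WORD_THRESHOLDS = {"en": 40, "de": 40, "zh": 60}

def mock_chunk_text_for_inference(
    text: str,
    language: str,
    tokenize: Callable[[str], List[int]],
    eos_token_id: int,
) -> Tuple[List[torch.Tensor], List[int]]:
    threshold = LANGUAGE_WORD_THRESHOLDS.get(language, 40)
    if len(text.split()) <= threshold:
        pieces = [text]  # short text: a single chunk
    else:
        pieces = [s for s in re.split(r"(?<=[.!?])\s+", text) if s]  # long text: one chunk per sentence
    chunked_tokens, chunked_tokens_len = [], []
    for piece in pieces:
        ids = tokenize(piece) + [eos_token_id]
        chunked_tokens.append(torch.tensor(ids, dtype=torch.int32))
        chunked_tokens_len.append(len(ids))
    return chunked_tokens, chunked_tokens_len

# Usage with a toy character-level tokenizer:
tokens, lens = mock_chunk_text_for_inference(
    "Hello world. This is fine.", "en", tokenize=lambda s: [ord(c) for c in s], eos_token_id=0
)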
@@ -887,15 +920,18 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:

def collate_fn(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Collate function for batching longform samples.
Collate function for batching unified inference samples.

Calls parent's collate_fn to handle context audio, text conditioning, etc.,
then adds longform-specific fields (chunked_tokens).
then adds chunking-related fields (chunked_tokens).

Handles mixed batches where samples have different numbers of chunks by
padding shorter samples with EOS tokens.
"""
# Call parent's collate_fn to handle all standard fields
batch_dict = super().collate_fn(batch)

# Add longform-specific fields
# Add chunking-related fields
indices = []
chunked_tokens_list = []
chunked_tokens_lens_list = []
@@ -916,7 +952,7 @@ def collate_fn(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
chunked_tokens_list.append(padded_tokens)
chunked_tokens_lens_list.append(padded_lens)

# Add longform-specific fields to batch_dict
# Add chunking-related fields to batch_dict
batch_dict['idx'] = indices
batch_dict['chunked_tokens'] = chunked_tokens_list
batch_dict['chunked_tokens_lens'] = chunked_tokens_lens_list
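Editor's note: the collate_fn docstring above says mixed batches are handled by padding samples with fewer chunks using EOS tokens; the padding code itself is collapsed in this diff. The sketch below shows one way such padding could work; the function name and details are assumptions, not the PR's actual implementation.

from typing import List, Tuple
import torch

def pad_chunk_lists(
    per_sample_chunks: List[List[torch.Tensor]],
    eos_id: int,
) -> Tuple[List[List[torch.Tensor]], List[List[int]]]:
    # Pad every sample up to the largest chunk count in the batch using 1-token EOS chunks.
    max_chunks = max(len(chunks) for chunks in per_sample_chunks)
    padded_tokens, padded_lens = [], []
    for chunks in per_sample_chunks:
        lens = [int(c.shape[0]) for c in chunks]
        missing = max_chunks - len(chunks)
        chunks = chunks + [torch.tensor([eos_id], dtype=torch.int32)] * missing
        lens = lens + [1] * missing
        padded_tokens.append(chunks)
        padded_lens.append(lens)
    return padded_tokens, padded_lens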