5 changes: 5 additions & 0 deletions examples/tts/evalset_config.json
@@ -8,6 +8,11 @@
"manifest_path": "/home/TestData/an4_dataset/an4_val_context_v1_tiny.json",
"audio_dir": "/",
"feature_dir": null
},
"an4_val_ci_longform_tiny": {
"manifest_path": "/home/TestData/an4_dataset/an4_val_context_v1_longform_tiny.json",
"audio_dir": "/",
"feature_dir": null
}
}
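
For context, here is a minimal sketch of what a single line in the new long-form manifest might contain. The field names (audio_filepath, context_audio_filepath, language, and so on) are illustrative assumptions inferred from how the dataset code reads text, language, and context audio; they are not copied from the actual an4 test file.

```python
# Illustrative only: hypothetical manifest line for a long-form context eval set.
# Field names are assumptions, not taken from the real an4 manifest.
import json

entry = {
    "audio_filepath": "an4/wav/an4_val_0001.wav",               # hypothetical path
    "duration": 42.7,
    "text": "A long passage spanning several sentences ...",
    "context_audio_filepath": "an4/wav/an4_context_0001.wav",   # hypothetical path
    "context_audio_duration": 4.2,
    "language": "en",
}

with open("an4_val_context_v1_longform_tiny.json", "w") as f:
    f.write(json.dumps(entry) + "\n")  # NeMo manifests are JSON-lines files
```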

52 changes: 10 additions & 42 deletions examples/tts/magpietts_inference.py
@@ -152,9 +152,9 @@ def run_inference_and_evaluation(
) -> Tuple[Optional[float], Optional[float]]:
"""Run inference and optional evaluation on specified datasets.

Longform inference is automatically detected based on text characteristics
when longform_mode="auto" (default). Use longform_mode="always" or "never"
for explicit control.
Uses unified inference path with automatic text chunking based on
per-sample language thresholds. Short texts are processed as single chunks, while
long texts are automatically split into sentences.

Args:
model_config: Configuration for loading the model.
@@ -192,8 +192,8 @@ def run_inference_and_evaluation(
# Build full checkpoint identifier
full_checkpoint_name = f"{checkpoint_name}_{inference_config.build_identifier()}_SV_{eval_config.sv_model}"

# Create inference runner (auto-detects longform based on config.longform_mode)
logging.info(f"Longform mode: {inference_config.longform_mode}")
# Create inference runner (uses unified path with automatic text chunking)
logging.info("Using unified inference with automatic text chunking based on language thresholds")
runner = MagpieInferenceRunner(model, inference_config)

# Tracking metrics across datasets
@@ -445,25 +445,6 @@ def create_argument_parser() -> argparse.ArgumentParser:
infer_group.add_argument(f"--{field.name}", **extra_args)
infer_group.add_argument('--batch_size', type=int, default=32)
infer_group.add_argument('--use_cfg', action='store_true', help='Enable classifier-free guidance')
infer_group.add_argument(
'--longform_mode',
type=str,
default='auto',
choices=['auto', 'always', 'never'],
help='Longform inference mode: auto (detect from text), always, or never',
)
infer_group.add_argument(
'--longform_word_threshold',
type=int,
default=40,
help='Word threshold for auto-detection of longform text',
)
infer_group.add_argument(
'--longform_max_decoder_steps',
type=int,
default=50000,
help='Maximum decoder steps for longform inference',
)

# Local transformer / MaskGit arguments
infer_group.add_argument('--use_local_transformer', action='store_true')
@@ -531,26 +512,15 @@ def main(argv=None):
parser.error("You must provide either:\n 1. --hparams_files and --checkpoint_files\n 2. --nemo_files")

# Build configurations
# Use higher max_decoder_steps for longform inference when mode is 'always'
if args.longform_mode == 'always':
max_decoder_steps = args.longform_max_decoder_steps
elif args.longform_mode == 'auto':
# Use longform steps if any text appears long (will be checked in runner)
max_decoder_steps = args.longform_max_decoder_steps
else: # 'never'
max_decoder_steps = 440
model_inference_parameters = {}
for field in fields(ModelInferenceParameters):
field = field.name
if field == "max_decoder_steps":
model_inference_parameters[field] = max_decoder_steps
continue
arg_from_cmdline = vars(args)[field]
field_name = field.name
arg_from_cmdline = vars(args)[field_name]
if arg_from_cmdline is not None:
if field in ["estimate_alignment_from_layers", "apply_prior_to_layers"]:
model_inference_parameters[field] = parse_layer_list(vars(args)[field])
if field_name in ["estimate_alignment_from_layers", "apply_prior_to_layers"]:
model_inference_parameters[field_name] = parse_layer_list(arg_from_cmdline)
else:
model_inference_parameters[field] = vars(args)[field]
model_inference_parameters[field_name] = arg_from_cmdline

inference_config = InferenceConfig(
model_inference_parameters=ModelInferenceParameters.from_dict(model_inference_parameters),
@@ -562,8 +532,6 @@
maskgit_noise_scale=args.maskgit_noise_scale,
maskgit_fixed_schedule=args.maskgit_fixed_schedule,
maskgit_sampling_type=args.maskgit_sampling_type,
longform_mode=args.longform_mode,
longform_word_threshold=args.longform_word_threshold,
)

eval_config = EvaluationConfig(
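To make the simplified argument-mapping loop above concrete, here is a self-contained sketch of the same pattern with a toy dataclass and a stand-in parse_layer_list; the class, defaults, and flag values are hypothetical and not NeMo's ModelInferenceParameters.

```python
# Standalone sketch of the dataclass-fields -> argparse -> kwargs pattern used above.
# ToyInferenceParameters and parse_layer_list below are stand-ins, not NeMo code.
import argparse
from dataclasses import dataclass, fields

@dataclass
class ToyInferenceParameters:
    temperature: float = 0.7
    max_decoder_steps: int = 440
    apply_prior_to_layers: list = None

def parse_layer_list(value: str) -> list:
    # e.g. "0,1" -> [0, 1]
    return [int(x) for x in value.split(",") if x.strip()]

parser = argparse.ArgumentParser()
for f in fields(ToyInferenceParameters):
    parser.add_argument(f"--{f.name}", default=None)
args = parser.parse_args(["--temperature", "0.9", "--apply_prior_to_layers", "0,1"])

params = {}
for f in fields(ToyInferenceParameters):
    field_name = f.name                      # keep the Field object and its name separate
    arg_from_cmdline = vars(args)[field_name]
    if arg_from_cmdline is not None:
        if field_name in ["apply_prior_to_layers"]:
            params[field_name] = parse_layer_list(arg_from_cmdline)
        else:
            params[field_name] = arg_from_cmdline

print(params)  # {'temperature': '0.9', 'apply_prior_to_layers': [0, 1]}
```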
90 changes: 63 additions & 27 deletions nemo/collections/tts/data/text_to_speech_dataset.py
@@ -31,7 +31,9 @@
_read_audio,
beta_binomial_prior_distribution,
chunk_and_tokenize_text_by_sentence,
chunk_text_for_inference,
filter_dataset_by_duration,
get_tokenizer_for_language,
get_weighted_sampler,
load_audio,
stack_tensors,
@@ -785,23 +787,25 @@ def collate_fn(self, batch: List[dict]):
return {"chosen": chosen_collated, "rejected": rejected_collated}


class LongFormTTSInferenceDataset(MagpieTTSDataset):
class ChunkedTTSInferenceDataset(MagpieTTSDataset):
"""
Dataset for longform TTS inference with sentence-level text chunking.
Unified dataset for TTS inference with automatic text chunking.

Inherits from MagpieTTSDataset to reuse context audio loading, text conditioning,
and other preprocessing logic. Adds sentence-level text chunking on top.
and other preprocessing logic. Uses language-aware chunking to automatically
decide whether to split text into sentences:
- Short text (below language threshold): returns a single chunk
- Long text (above language threshold): returns multiple sentence chunks

Both language (for threshold) and tokenizer are determined per-sample:
- Language from manifest's 'language' field
- Tokenizer from sample's tokenizer_names or mapped from language

Args:
dataset_meta: Dataset metadata dictionary (same format as MagpieTTSDataset).
sample_rate: Audio sample rate.
tokenizer_name: Name of the tokenizer to use for sentence chunking.
codec_model_samples_per_frame: Samples per codec frame.
eos_id: End-of-sequence token ID.
audio_bos_id: Audio BOS token ID (for target audio).
audio_eos_id: Audio EOS token ID (for target audio).
context_audio_bos_id: Context audio BOS token ID.
context_audio_eos_id: Context audio EOS token ID.
num_audio_codebooks: Number of audio codebooks.
context_duration_min: Minimum context duration in seconds.
context_duration_max: Maximum context duration in seconds.
@@ -815,7 +819,6 @@ def __init__(
self,
dataset_meta: Dict[str, Any],
sample_rate: int,
tokenizer_name: str,
codec_model_samples_per_frame: int,
eos_id: int,
num_audio_codebooks: int,
@@ -844,41 +847,71 @@ def __init__(
dataset_type='test',
**kwargs,
)
self.tokenizer_name = tokenizer_name

def _get_tokenizer_name(self, data, language: str) -> str:
"""Get tokenizer name for a sample, from tokenizer_names or language mapping.

Args:
data: DatasetSample with optional tokenizer_names field.
language: Language code from manifest.

Returns:
Tokenizer name to use for encoding.
"""
# First try sample's tokenizer_names (from dataset config)
if data.tokenizer_names is not None:
return data.tokenizer_names[0] # Use first (deterministic for inference)
Comment on lines +861 to +863

Collaborator: Where does this tokenizer_names parameter come from? It seems to assume that you read it from the dataset json which doesn't seem ideal.

Collaborator: And do we only support this in the non-Lhotse path? What happens if we try a Lhotse dataset?

Collaborator (Author): We do inference on the non-Lhotse path only.


# Fall back to centralized language-based mapping
if self.text_tokenizer is not None:
available = list(self.text_tokenizer.tokenizers.keys())
return get_tokenizer_for_language(language, available)

return "english_phoneme"

def __getitem__(self, idx: int) -> Dict[str, Any]:
"""
Add sentence chunking on top of parent's __getitem__.
Add automatic text chunking on top of parent's __getitem__.

Uses language-aware chunking that automatically decides whether to split:
- Short text (below threshold): returns as single chunk
- Long text (above threshold): returns as multiple sentence chunks

Both tokenizer and chunking threshold are determined per-sample based on
language and tokenizer configuration.

Returns:
Dictionary containing all parent fields plus:
- idx: Sample index
- chunked_tokens: List of tokenized text chunks (per sentence)
- chunked_tokens: List of tokenized text chunks (1 for short, N for long)
- chunked_tokens_len: List of token lengths
- entry: Original manifest entry
"""
# Get text for sentence chunking
# Get data sample for text and tokenizer info
data = self.data_samples[idx]
text = data.text # entry.get("normalized_text", entry.get("text", ""))
text = data.text

# Call parent to get ALL the context audio, text conditioning, etc.
example = super().__getitem__(idx)

# Sentence chunking (longform-specific)
chunked_tokens, chunked_tokens_len, _ = chunk_and_tokenize_text_by_sentence(
text,
self.tokenizer_name,
self.text_tokenizer,
self.eos_id,
language=example['language'],
# Get language and tokenizer per-sample
language = example["language"]
tokenizer_name = self._get_tokenizer_name(data, language)

# Unified chunking: automatically decides whether to split based on language threshold
chunked_tokens, chunked_tokens_len, _ = chunk_text_for_inference(
text=text,
language=language,
tokenizer_name=tokenizer_name,
text_tokenizer=self.text_tokenizer,
eos_token_id=self.eos_id,
)

# Handle empty text edge case
if not chunked_tokens:
chunked_tokens = [torch.tensor([self.eos_id], dtype=torch.int32)]
chunked_tokens_len = [1]

# Add longform-specific fields
# Add chunking-related fields
example['idx'] = idx
example['chunked_tokens'] = chunked_tokens
example['chunked_tokens_len'] = chunked_tokens_len
@@ -887,15 +920,18 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:

def collate_fn(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Collate function for batching longform samples.
Collate function for batching unified inference samples.

Calls parent's collate_fn to handle context audio, text conditioning, etc.,
then adds longform-specific fields (chunked_tokens).
then adds chunking-related fields (chunked_tokens).

Handles mixed batches where samples have different numbers of chunks by
padding shorter samples with EOS tokens.
"""
# Call parent's collate_fn to handle all standard fields
batch_dict = super().collate_fn(batch)

# Add longform-specific fields
# Add chunking-related fields
indices = []
chunked_tokens_list = []
chunked_tokens_lens_list = []
@@ -916,7 +952,7 @@
chunked_tokens_list.append(padded_tokens)
chunked_tokens_lens_list.append(padded_lens)

# Add longform-specific fields to batch_dict
# Add chunking-related fields to batch_dict
batch_dict['idx'] = indices
batch_dict['chunked_tokens'] = chunked_tokens_list
batch_dict['chunked_tokens_lens'] = chunked_tokens_lens_list
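
For reference, a minimal self-contained sketch of the behaviour the new dataset docstrings describe: short text yields a single chunk, long text yields one chunk per sentence, and the collate step pads shorter samples' chunk lists with EOS-only chunks so mixed batches line up. The word threshold, toy tokenization, and helper names are assumptions for illustration, not the real chunk_text_for_inference implementation.

```python
# Sketch of the chunking contract and the EOS padding done in collate_fn.
# The threshold and whitespace "tokenizer" are stand-ins for the real
# language-aware thresholds and NeMo tokenizers; only the shapes matter here.
import torch

EOS_ID = 0
WORD_THRESHOLD = 40  # assumed per-language threshold; real values come from NeMo utils

def toy_chunk_text_for_inference(text: str):
    """Return one chunk for short text, one chunk per sentence for long text."""
    if len(text.split()) <= WORD_THRESHOLD:
        sentences = [text]
    else:
        sentences = [s.strip() + "." for s in text.split(".") if s.strip()]
    chunks = [
        torch.tensor([hash(w) % 100 for w in s.split()] + [EOS_ID], dtype=torch.int32)
        for s in sentences
    ]
    return chunks, [len(c) for c in chunks]

def pad_chunk_lists(batch_chunks):
    """Pad every sample's chunk list to the longest chunk count with EOS-only chunks,
    so a batch can mix short (1-chunk) and long (N-chunk) samples."""
    max_chunks = max(len(chunks) for chunks in batch_chunks)
    padded = []
    for chunks in batch_chunks:
        pad = [torch.tensor([EOS_ID], dtype=torch.int32)] * (max_chunks - len(chunks))
        padded.append(chunks + pad)
    return padded

short, _ = toy_chunk_text_for_inference("Hello world.")
long, _ = toy_chunk_text_for_inference("First sentence. " * 30)
print(len(short), len(long))                                 # 1 chunk vs. ~30 chunks
print([len(c) for c in pad_chunk_lists([short, long])[0]])   # short sample padded to match
```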