diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py
index 1690efb748d0..713b49ab85b8 100644
--- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py
+++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py
@@ -1106,7 +1106,7 @@ def _compute_offsets(hypothesis: Hypothesis, blank_id: int) -> List[Dict[str, Un
         Returns:
             List[Dict[str, Union[str, int]]]: A list of dictionaries, where each dictionary contains:
-                - "char": List[str] - The character/subword token
+                - "char": List[int] - The character/subword token ID
                 - "start_offset": int - The start time index of the token
                 - "end_offset": int - The end time index of the token
 
@@ -1127,32 +1127,78 @@ def _compute_offsets(hypothesis: Hypothesis, blank_id: int) -> List[Dict[str, Un
     @staticmethod
     def _compute_offsets_tdt(hypothesis: Hypothesis, blank_id: int, *args) -> List[Dict[str, Union[str, int]]]:
         """
-        Utility method that calculates the indidual time indices where a token starts and ends.
+        Utility method that calculates the individual time indices where a token starts and ends for TDT models.
+
+        This method handles both START and END timestamp semantics. Beam search in TDT stores END timestamps
+        (timesteps + duration), while greedy decoding stores START timestamps. The method uses an explicit
+        _timestamp_semantics flag or a fallback heuristic to determine the correct interpretation.
 
         Args:
-            hypothesis: A Hypothesis object that contains `text` field that holds the character / subword token
-                emitted at a specific time step considering predicted durations of the previous tokens.
+            hypothesis: A Hypothesis object that contains `timestamp` and `token_duration` fields.
+            blank_id: The index representing the blank token in the vocabulary.
+            *args: Additional arguments (unused, for compatibility).
 
         Returns:
             List[Dict[str, Union[str, int]]]: A list of dictionaries, where each dictionary contains:
-                - "char": List[str] - The character/subword token
+                - "char": List[int] - The character/subword token ID
                 - "start_offset": int - The start time index of the token
                 - "end_offset": int - The end time index of the token
 
         **Note**: Blank tokens are not included in the offsets.
         """
-        if isinstance(hypothesis.timestamp, torch.Tensor):
-            hypothesis.token_duration = hypothesis.token_duration.cpu().tolist()
-
+        # Handle missing token_duration (compute from timestamp diffs)
+        if hypothesis.token_duration is None:
+            if hypothesis.timestamp is None or len(hypothesis.timestamp) == 0:
+                return []
+            if isinstance(hypothesis.timestamp, torch.Tensor):
+                timestamps = hypothesis.timestamp.cpu().tolist()
+            else:
+                timestamps = list(hypothesis.timestamp)
+            durations = []
+            prev = 0
+            for curr in timestamps:
+                durations.append(max(0, curr - prev))
+                prev = curr
+            hypothesis.token_duration = durations
+
+        # Convert to lists
         if isinstance(hypothesis.timestamp, torch.Tensor):
             hypothesis.timestamp = hypothesis.timestamp.cpu().tolist()
+        if isinstance(hypothesis.token_duration, torch.Tensor):
+            hypothesis.token_duration = hypothesis.token_duration.cpu().tolist()
 
-        # Merge the results per token into a list of dictionaries
-        offsets = [
-            {"char": [t], "start_offset": s, "end_offset": s + d}
-            for t, s, d in zip(hypothesis.y_sequence, hypothesis.timestamp, hypothesis.token_duration)
-            if t != blank_id
-        ]
+        # Determine timestamp semantics
+        semantics_flag = getattr(hypothesis, "_timestamp_semantics", None)
+        if semantics_flag == "end":
+            timestamps_are_ends = True
+        elif semantics_flag == "start":
+            timestamps_are_ends = False
+        else:
+            # Fallback heuristic: END timestamps never precede cumulative duration
+            timestamps_are_ends = False
+            if len(hypothesis.timestamp) > 0 and len(hypothesis.token_duration) > 0:
+                cumulative = 0
+                matches_end_semantics = True
+                for ts, dur in zip(hypothesis.timestamp, hypothesis.token_duration):
+                    cumulative += dur
+                    if ts + 1 < cumulative:  # 1-frame slack
+                        matches_end_semantics = False
+                        break
+                if matches_end_semantics and hypothesis.timestamp[0] >= hypothesis.token_duration[0]:
+                    timestamps_are_ends = True
+
+        # Compute offsets
+        offsets = []
+        for t, ts, d in zip(hypothesis.y_sequence, hypothesis.timestamp, hypothesis.token_duration):
+            if t == blank_id:
+                continue
+            if timestamps_are_ends:
+                start_offset = ts - d
+                end_offset = ts
+            else:
+                start_offset = ts
+                end_offset = ts + d
+            offsets.append({"char": [t], "start_offset": start_offset, "end_offset": end_offset})
 
         return offsets
 
     @staticmethod
diff --git a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py
index e0af3b7623d7..7f407b8645c9 100644
--- a/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py
+++ b/nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py
@@ -171,6 +171,13 @@ def __init__(
         self.next_timestamp = torch.zeros((batch_size, self.beam_size), device=device, dtype=torch.long)
         self.last_timestamp_lasts = torch.zeros((batch_size, self.beam_size), device=device, dtype=torch.long)
 
+        # For TDT models, also store token durations to enable correct timestamp offset computation.
+        # Beam search stores END timestamps (timesteps + duration), but offset computation expects START.
+        if self.model_type == ASRModelTypeEnum.TDT:
+            self.token_durations = torch.zeros(
+                (batch_size, self.beam_size, self._max_length), device=device, dtype=torch.long
+            )
+
     def clear_(self):
         """
         Clears and resets the internal state of the object.
@@ -199,6 +206,9 @@ def clear_(self):
         self.next_timestamp.fill_(0)
         self.last_timestamp_lasts.fill_(0)
 
+        if self.model_type == ASRModelTypeEnum.TDT:
+            self.token_durations.fill_(0)
+
     def _allocate_more(self):
         """
         Dynamically allocates more memory for the internal buffers.
@@ -216,6 +226,9 @@ def _allocate_more(self):
         else:
             self.timestamps = torch.cat((self.timestamps, torch.zeros_like(self.timestamps)), dim=-1)
 
+        if self.model_type == ASRModelTypeEnum.TDT:
+            self.token_durations = torch.cat((self.token_durations, torch.zeros_like(self.token_durations)), dim=-1)
+
         self._max_length *= 2
 
     def add_results_(
@@ -301,6 +314,13 @@ def add_results_no_checks_(
             index=self.current_lengths_wb.unsqueeze(-1),
             src=(timesteps + next_label_durations).unsqueeze(-1),
         )
+        # Also store token durations for TDT (needed for correct offset computation)
+        durations_to_store = torch.where(is_extended, next_label_durations, 0)
+        self.token_durations.scatter_(
+            dim=-1,
+            index=self.current_lengths_wb.unsqueeze(-1),
+            src=durations_to_store.unsqueeze(-1),
+        )
         torch.where(is_extended, timesteps + next_label_durations, timesteps, out=self.next_timestamp)
         torch.where(
             is_extended & (next_label_durations > 0),
@@ -474,19 +494,28 @@ def to_hyps_list(self, score_norm: bool = True) -> list[Hypothesis]:
         max_idx = self.current_lengths_wb.max() - 1
         timestamps = self.timestamps[..., 0, : max_idx + 1]
         transcripts = self.transcript_wb[..., 0, : max_idx + 1]
-        hypotheses = [
-            Hypothesis(
+
+        if self.model_type == ASRModelTypeEnum.TDT:
+            token_durations = self.token_durations[..., 0, : max_idx + 1]
+
+        hypotheses = []
+        for batch_idx in range(self.batch_size):
+            mask = self._create_transcripts_mask(transcripts[batch_idx])
+            hyp = Hypothesis(
                 score=scores[batch_idx],
-                y_sequence=transcripts[batch_idx][mask := self._create_transcripts_mask(transcripts[batch_idx])]
-                .cpu()
-                .detach()
-                .numpy(),
+                y_sequence=transcripts[batch_idx][mask].cpu().detach().numpy(),
                 timestamp=timestamps[batch_idx][mask].cpu().detach().numpy(),
                 alignments=None,
                 dec_state=None,
             )
-            for batch_idx in range(self.batch_size)
-        ]
+
+            # Populate token_duration and mark semantics for TDT
+            if self.model_type == ASRModelTypeEnum.TDT:
+                hyp.token_duration = token_durations[batch_idx][mask].cpu().detach().numpy().tolist()
+                hyp._timestamp_semantics = "end"  # Explicit marker
+
+            hypotheses.append(hyp)
+
         return hypotheses
 
     def to_nbest_hyps_list(self, score_norm: bool = True) -> list[NBestHypotheses]:
@@ -506,27 +535,33 @@ def to_nbest_hyps_list(self, score_norm: bool = True) -> list[NBestHypotheses]:
         max_idx = self.current_lengths_wb.max() - 1
         transcripts = self.transcript_wb[..., : max_idx + 1]
         timestamps = self.timestamps[..., : max_idx + 1]
-        hypotheses = [
-            NBestHypotheses(
-                [
-                    Hypothesis(
+
+        if self.model_type == ASRModelTypeEnum.TDT:
+            token_durations = self.token_durations[..., : max_idx + 1]
+
+        hypotheses = []
+        for batch_idx in range(self.batch_size):
+            batch_hyps = []
+            for beam_idx in range(self.beam_size):
+                if scores[batch_idx][beam_idx] > INACTIVE_SCORE:
+                    mask = self._create_transcripts_mask(transcripts[batch_idx][beam_idx])
+                    hyp = Hypothesis(
                         score=scores[batch_idx][beam_idx],
-                        y_sequence=transcripts[batch_idx][beam_idx][
-                            mask := self._create_transcripts_mask(transcripts[batch_idx][beam_idx])
-                        ]
-                        .cpu()
-                        .detach()
-                        .numpy(),
+                        y_sequence=transcripts[batch_idx][beam_idx][mask].cpu().detach().numpy(),
                         timestamp=timestamps[batch_idx][beam_idx][mask].cpu().detach().numpy(),
                         alignments=None,
                         dec_state=None,
                     )
-                    for beam_idx in range(self.beam_size)
-                    if scores[batch_idx][beam_idx] > INACTIVE_SCORE
-                ]
-            )
-            for batch_idx in range(self.batch_size)
-        ]
+
+                    # Populate token_duration and mark semantics for TDT
+                    if self.model_type == ASRModelTypeEnum.TDT:
+                        hyp.token_duration = token_durations[batch_idx][beam_idx][mask].cpu().detach().numpy().tolist()
+                        hyp._timestamp_semantics = "end"  # Explicit marker
+
+                    batch_hyps.append(hyp)
+
+            hypotheses.append(NBestHypotheses(batch_hyps))
+
         return hypotheses
 
     def flatten_sort_(self, score_norm: bool = True):
@@ -552,6 +587,17 @@ def flatten_sort_(self, score_norm: bool = True):
         max_idx = self.current_lengths_wb.max() - 1
         ptrs = indices
 
+        # Sort token_durations BEFORE updating ptrs to avoid misalignment
+        if self.model_type == ASRModelTypeEnum.TDT:
+            token_durations_sorted = torch.zeros_like(self.token_durations)
+            temp_ptrs = indices
+            for idx in range(max_idx, -1, -1):
+                token_durations_sorted[..., idx].copy_(
+                    self.token_durations[self.batch_indices.unsqueeze(-1), temp_ptrs, idx]
+                )
+                temp_ptrs = self.transcript_wb_prev_ptr[self.batch_indices.unsqueeze(-1), temp_ptrs, idx]
+            self.token_durations.copy_(token_durations_sorted)
+
         for idx in range(max_idx, -1, -1):
             self.transcript_wb[..., idx].copy_(self.transcript_wb[self.batch_indices.unsqueeze(-1), ptrs, idx])
             if self.model_type == ASRModelTypeEnum.TDT or self.model_type == ASRModelTypeEnum.RNNT:
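Note (not part of the patch): the flatten_sort_ change depends on the linked-list layout of the batched beam buffers, where transcript_wb_prev_ptr[b, k, t] records which beam slot hypothesis (b, k) occupied at step t - 1, so reordering any per-step buffer requires chasing those pointers from the last step backwards. A standalone sketch of that traversal with hypothetical shapes and values:

import torch

batch_size, beam_size, max_len = 1, 2, 3
# durations[b, k, t]: value emitted by beam slot k at step t (hypothetical)
durations = torch.tensor([[[3, 2, 4], [1, 5, 2]]])
# prev_ptr[b, k, t]: slot this hypothesis occupied at the previous step
prev_ptr = torch.tensor([[[0, 1, 0], [1, 0, 1]]])
# indices[b, k]: new beam order after sorting by score (here: swap the slots)
indices = torch.tensor([[1, 0]])

batch_indices = torch.arange(batch_size)
sorted_durations = torch.zeros_like(durations)
ptrs = indices
for t in range(max_len - 1, -1, -1):
    # Gather the value each re-ordered hypothesis emitted at step t ...
    sorted_durations[..., t] = durations[batch_indices.unsqueeze(-1), ptrs, t]
    # ... then hop to the slot it occupied one step earlier.
    ptrs = prev_ptr[batch_indices.unsqueeze(-1), ptrs, t]

# New beam 0 walks old beam 1's backward chain (slots 1, 1, 0),
# which yields [3, 5, 2] in time order.
assert sorted_durations.tolist() == [[[3, 5, 2], [1, 2, 4]]]

This is also why the patch snapshots token_durations with a separate temp_ptrs walk before the main loop advances ptrs: both traversals must start from the same sorted indices, or the durations would be misaligned with the reordered transcripts.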