74 changes: 60 additions & 14 deletions nemo/collections/asr/parts/submodules/rnnt_decoding.py
@@ -1106,7 +1106,7 @@ def _compute_offsets(hypothesis: Hypothesis, blank_id: int) -> List[Dict[str, Un

Returns:
List[Dict[str, Union[str, int]]]: A list of dictionaries, where each dictionary contains:
- "char": List[str] - The character/subword token
- "char": List[int] - The character/subword token ID
- "start_offset": int - The start time index of the token
- "end_offset": int - The end time index of the token

@@ -1127,32 +1127,78 @@ def _compute_offsets(hypothesis: Hypothesis, blank_id: int) -> List[Dict[str, Un
@staticmethod
def _compute_offsets_tdt(hypothesis: Hypothesis, blank_id: int, *args) -> List[Dict[str, Union[str, int]]]:
"""
Utility method that calculates the indidual time indices where a token starts and ends.
Utility method that calculates the individual time indices where a token starts and ends for TDT models.

This method handles both START and END timestamp semantics. Beam search in TDT stores END timestamps
(timesteps + duration), while greedy decoding stores START timestamps. The method uses an explicit
_timestamp_semantics flag or a fallback heuristic to determine the correct interpretation.

Args:
hypothesis: A Hypothesis object that contains `text` field that holds the character / subword token
emitted at a specific time step considering predicted durations of the previous tokens.
hypothesis: A Hypothesis object that contains `timestamp` and `token_duration` fields.
blank_id: The index representing the blank token in the vocabulary.
*args: Additional arguments (unused, for compatibility).

Returns:
List[Dict[str, Union[str, int]]]: A list of dictionaries, where each dictionary contains:
- "char": List[str] - The character/subword token
- "char": List[int] - The character/subword token ID
- "start_offset": int - The start time index of the token
- "end_offset": int - The end time index of the token

**Note**: Blank tokens are not included in the offsets.
"""
if isinstance(hypothesis.timestamp, torch.Tensor):
hypothesis.token_duration = hypothesis.token_duration.cpu().tolist()

# Handle missing token_duration (compute from timestamp diffs)
if hypothesis.token_duration is None:
if hypothesis.timestamp is None or len(hypothesis.timestamp) == 0:
return []
if isinstance(hypothesis.timestamp, torch.Tensor):
timestamps = hypothesis.timestamp.cpu().tolist()
else:
timestamps = list(hypothesis.timestamp)
durations = []
prev = 0
for curr in timestamps:
durations.append(max(0, curr - prev))
prev = curr
hypothesis.token_duration = durations

# Convert to lists
if isinstance(hypothesis.timestamp, torch.Tensor):
hypothesis.timestamp = hypothesis.timestamp.cpu().tolist()
if isinstance(hypothesis.token_duration, torch.Tensor):
hypothesis.token_duration = hypothesis.token_duration.cpu().tolist()

# Merge the results per token into a list of dictionaries
offsets = [
{"char": [t], "start_offset": s, "end_offset": s + d}
for t, s, d in zip(hypothesis.y_sequence, hypothesis.timestamp, hypothesis.token_duration)
if t != blank_id
]
# Determine timestamp semantics
semantics_flag = getattr(hypothesis, "_timestamp_semantics", None)
if semantics_flag == "end":
timestamps_are_ends = True
elif semantics_flag == "start":
timestamps_are_ends = False
else:
# Fallback heuristic: END timestamps never precede cumulative duration
timestamps_are_ends = False
if len(hypothesis.timestamp) > 0 and len(hypothesis.token_duration) > 0:
cumulative = 0
matches_end_semantics = True
for ts, dur in zip(hypothesis.timestamp, hypothesis.token_duration):
cumulative += dur
if ts + 1 < cumulative: # 1-frame slack
matches_end_semantics = False
break
if matches_end_semantics and hypothesis.timestamp[0] >= hypothesis.token_duration[0]:
timestamps_are_ends = True

# Compute offsets
offsets = []
for t, ts, d in zip(hypothesis.y_sequence, hypothesis.timestamp, hypothesis.token_duration):
if t == blank_id:
continue
if timestamps_are_ends:
start_offset = ts - d
end_offset = ts
else:
start_offset = ts
end_offset = ts + d
offsets.append({"char": [t], "start_offset": start_offset, "end_offset": end_offset})
return offsets

@staticmethod
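To make the two timestamp conventions concrete, here is a minimal, self-contained sketch of the resolution logic in `_compute_offsets_tdt` above. The toy `Hyp` class and `compute_offsets_tdt` function are hypothetical stand-ins for NeMo's `Hypothesis` and the method in this diff; token IDs, frames, and durations are invented. Greedy-style START timestamps and beam-style END timestamps resolve to identical offsets:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Union


@dataclass
class Hyp:  # toy stand-in for nemo's Hypothesis
    y_sequence: List[int]
    timestamp: List[int]
    token_duration: List[int]
    semantics: Optional[str] = None  # mirrors _timestamp_semantics


def compute_offsets_tdt(hyp: Hyp, blank_id: int) -> List[Dict[str, Union[List[int], int]]]:
    # Explicit flag wins; otherwise fall back to the heuristic:
    # END timestamps never drop (beyond 1-frame slack) below the
    # cumulative duration of the tokens emitted so far.
    if hyp.semantics in ("start", "end"):
        ends = hyp.semantics == "end"
    else:
        cumulative, ends = 0, bool(hyp.timestamp)
        for ts, dur in zip(hyp.timestamp, hyp.token_duration):
            cumulative += dur
            if ts + 1 < cumulative:  # 1-frame slack
                ends = False
                break
        ends = ends and hyp.timestamp[0] >= hyp.token_duration[0]
    return [
        {
            "char": [t],
            "start_offset": ts - d if ends else ts,
            "end_offset": ts if ends else ts + d,
        }
        for t, ts, d in zip(hyp.y_sequence, hyp.timestamp, hyp.token_duration)
        if t != blank_id
    ]


# Same utterance under both conventions yields identical offsets:
greedy = Hyp([5, 9, 2], [0, 2, 5], [2, 3, 1])        # START times, via heuristic
beam = Hyp([5, 9, 2], [2, 5, 6], [2, 3, 1], "end")   # END times, explicit flag
assert compute_offsets_tdt(greedy, 0) == compute_offsets_tdt(beam, 0)
```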
94 changes: 70 additions & 24 deletions nemo/collections/asr/parts/utils/batched_beam_decoding_utils.py
@@ -171,6 +171,13 @@ def __init__(
self.next_timestamp = torch.zeros((batch_size, self.beam_size), device=device, dtype=torch.long)
self.last_timestamp_lasts = torch.zeros((batch_size, self.beam_size), device=device, dtype=torch.long)

# For TDT models, also store token durations to enable correct timestamp offset computation
# Beam search stores END timestamps (timesteps + duration), but offset computation expects START timestamps
if self.model_type == ASRModelTypeEnum.TDT:
self.token_durations = torch.zeros(
(batch_size, self.beam_size, self._max_length), device=device, dtype=torch.long
)

def clear_(self):
"""
Clears and resets the internal state of the object.
@@ -199,6 +206,9 @@ def clear_(self):
self.next_timestamp.fill_(0)
self.last_timestamp_lasts.fill_(0)

if self.model_type == ASRModelTypeEnum.TDT:
self.token_durations.fill_(0)

def _allocate_more(self):
"""
Dynamically allocates more memory for the internal buffers.
@@ -216,6 +226,9 @@ def _allocate_more(self):
else:
self.timestamps = torch.cat((self.timestamps, torch.zeros_like(self.timestamps)), dim=-1)

if self.model_type == ASRModelTypeEnum.TDT:
self.token_durations = torch.cat((self.token_durations, torch.zeros_like(self.token_durations)), dim=-1)

self._max_length *= 2

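For reference, the growth strategy in `_allocate_more` is plain amortized doubling; a toy sketch with invented shapes shows how every buffer, including the TDT-only `token_durations`, is grown in lockstep along the last dimension:

```python
import torch

# (batch, beam, max_len) buffer, grown by concatenating a zero block of
# the same shape along the last dim -- as done for timestamps,
# transcripts, and token_durations above.
buf = torch.zeros((2, 4, 8), dtype=torch.long)
buf = torch.cat((buf, torch.zeros_like(buf)), dim=-1)
assert buf.shape == (2, 4, 16)  # _max_length doubled
```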
def add_results_(
@@ -301,6 +314,13 @@ def add_results_no_checks_(
index=self.current_lengths_wb.unsqueeze(-1),
src=(timesteps + next_label_durations).unsqueeze(-1),
)
# Also store token durations for TDT (needed for correct offset computation)
durations_to_store = torch.where(is_extended, next_label_durations, 0)
self.token_durations.scatter_(
dim=-1,
index=self.current_lengths_wb.unsqueeze(-1),
src=durations_to_store.unsqueeze(-1),
)
torch.where(is_extended, timesteps + next_label_durations, timesteps, out=self.next_timestamp)
torch.where(
is_extended & (next_label_durations > 0),
@@ -474,19 +494,28 @@ def to_hyps_list(self, score_norm: bool = True) -> list[Hypothesis]:
max_idx = self.current_lengths_wb.max() - 1
timestamps = self.timestamps[..., 0, : max_idx + 1]
transcripts = self.transcript_wb[..., 0, : max_idx + 1]
hypotheses = [
Hypothesis(

if self.model_type == ASRModelTypeEnum.TDT:
token_durations = self.token_durations[..., 0, : max_idx + 1]

hypotheses = []
for batch_idx in range(self.batch_size):
mask = self._create_transcripts_mask(transcripts[batch_idx])
hyp = Hypothesis(
score=scores[batch_idx],
y_sequence=transcripts[batch_idx][mask := self._create_transcripts_mask(transcripts[batch_idx])]
.cpu()
.detach()
.numpy(),
y_sequence=transcripts[batch_idx][mask].cpu().detach().numpy(),
timestamp=timestamps[batch_idx][mask].cpu().detach().numpy(),
alignments=None,
dec_state=None,
)
for batch_idx in range(self.batch_size)
]

# Populate token_duration and mark semantics for TDT
if self.model_type == ASRModelTypeEnum.TDT:
hyp.token_duration = token_durations[batch_idx][mask].cpu().detach().numpy().tolist()
hyp._timestamp_semantics = "end" # Explicit marker

hypotheses.append(hyp)

return hypotheses

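As a consumer-side illustration, the invariants a TDT hypothesis returned by `to_hyps_list()` now satisfies can be checked as below. The field values are invented; only the alignment and END-semantics guarantees come from the code above (all three fields are gathered with the same mask, so they stay index-aligned):

```python
y_sequence = [5, 9, 2]       # gathered with the same mask as the fields below
timestamp = [2, 5, 6]        # END frames, i.e. start + duration
token_duration = [2, 3, 1]   # populated only for TDT hypotheses

assert len(y_sequence) == len(timestamp) == len(token_duration)
spans = [(ts - d, ts) for ts, d in zip(timestamp, token_duration)]
print(spans)  # [(0, 2), (2, 5), (5, 6)] -- per-token START/END frames
```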
def to_nbest_hyps_list(self, score_norm: bool = True) -> list[NBestHypotheses]:
@@ -506,27 +535,33 @@ def to_nbest_hyps_list(self, score_norm: bool = True) -> list[NBestHypotheses]:
max_idx = self.current_lengths_wb.max() - 1
transcripts = self.transcript_wb[..., : max_idx + 1]
timestamps = self.timestamps[..., : max_idx + 1]
hypotheses = [
NBestHypotheses(
[
Hypothesis(

if self.model_type == ASRModelTypeEnum.TDT:
token_durations = self.token_durations[..., : max_idx + 1]

hypotheses = []
for batch_idx in range(self.batch_size):
batch_hyps = []
for beam_idx in range(self.beam_size):
if scores[batch_idx][beam_idx] > INACTIVE_SCORE:
mask = self._create_transcripts_mask(transcripts[batch_idx][beam_idx])
hyp = Hypothesis(
score=scores[batch_idx][beam_idx],
y_sequence=transcripts[batch_idx][beam_idx][
mask := self._create_transcripts_mask(transcripts[batch_idx][beam_idx])
]
.cpu()
.detach()
.numpy(),
y_sequence=transcripts[batch_idx][beam_idx][mask].cpu().detach().numpy(),
timestamp=timestamps[batch_idx][beam_idx][mask].cpu().detach().numpy(),
alignments=None,
dec_state=None,
)
for beam_idx in range(self.beam_size)
if scores[batch_idx][beam_idx] > INACTIVE_SCORE
]
)
for batch_idx in range(self.batch_size)
]

# Populate token_duration and mark semantics for TDT
if self.model_type == ASRModelTypeEnum.TDT:
hyp.token_duration = token_durations[batch_idx][beam_idx][mask].cpu().detach().numpy().tolist()
hyp._timestamp_semantics = "end" # Explicit marker

batch_hyps.append(hyp)

hypotheses.append(NBestHypotheses(batch_hyps))

return hypotheses

def flatten_sort_(self, score_norm: bool = True):
@@ -552,6 +587,17 @@ def flatten_sort_(self, score_norm: bool = True):
max_idx = self.current_lengths_wb.max() - 1
ptrs = indices

# Sort token_durations BEFORE updating ptrs to avoid misalignment
if self.model_type == ASRModelTypeEnum.TDT:
token_durations_sorted = torch.zeros_like(self.token_durations)
temp_ptrs = indices
for idx in range(max_idx, -1, -1):
token_durations_sorted[..., idx].copy_(
self.token_durations[self.batch_indices.unsqueeze(-1), temp_ptrs, idx]
)
temp_ptrs = self.transcript_wb_prev_ptr[self.batch_indices.unsqueeze(-1), temp_ptrs, idx]
self.token_durations.copy_(token_durations_sorted)

for idx in range(max_idx, -1, -1):
self.transcript_wb[..., idx].copy_(self.transcript_wb[self.batch_indices.unsqueeze(-1), ptrs, idx])
if self.model_type == ASRModelTypeEnum.TDT or self.model_type == ASRModelTypeEnum.RNNT:
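The backward pointer walk added for `token_durations` in `flatten_sort_` can be illustrated with a single-batch, plain-Python sketch. The `resort_by_prev_ptr` helper and the toy data are hypothetical, but the loop mirrors the one above: gather each step's values in the current beam order, then follow `prev_ptr` one step back:

```python
def resort_by_prev_ptr(values, prev_ptr, order):
    """values[step][beam]; prev_ptr[step][beam] -> ancestor beam at step-1."""
    steps = len(values)
    out = [[0] * len(order) for _ in range(steps)]
    ptrs = list(order)                      # sorted beam order at the last step
    for step in range(steps - 1, -1, -1):   # walk timesteps backwards
        for slot, beam in enumerate(ptrs):
            out[step][slot] = values[step][beam]
        ptrs = [prev_ptr[step][beam] for beam in ptrs]
    return out


values = [[10, 20], [11, 21], [12, 22]]     # durations per (step, beam)
prev_ptr = [[0, 1], [1, 0], [0, 1]]         # beam ancestry per step
print(resort_by_prev_ptr(values, prev_ptr, order=[1, 0]))
# [[10, 20], [21, 11], [22, 12]] -- each beam's lineage stays consistent
```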