Speech to text translation utilizing 3-way data (#1099)
1 parent aa073f6 · commit c80fc07
Showing 6 changed files with 626 additions and 0 deletions.
@@ -0,0 +1,57 @@
from typing import Optional, Sequence, Union

import click

from lhotse.bin.modes import prepare
from lhotse.recipes.iwslt22_ta import prepare_iwslt22_ta
from lhotse.utils import Pathlike


@prepare.command(context_settings=dict(show_default=True))
@click.argument("corpus_dir", type=click.Path(exists=True, dir_okay=True))
@click.argument("splits", type=click.Path(exists=True, dir_okay=True))
@click.argument("output_dir", type=click.Path())
@click.option(
    "-j",
    "--num-jobs",
    type=int,
    default=1,
    help="How many threads to use (can give good speed-ups with slow disks).",
)
@click.option(
    "--normalize-text",
    is_flag=True,
    default=False,
    help="Whether to perform additional text cleaning and normalization from https://aclanthology.org/2022.iwslt-1.29.pdf.",
)
@click.option(
    "--langs",
    default="",
    help="Comma-separated list of language abbreviations for source and target languages.",
)
def iwslt22_ta(
    corpus_dir: Pathlike,
    splits: Pathlike,
    output_dir: Pathlike,
    normalize_text: bool,
    langs: str,
    num_jobs: int,
):
    """
    IWSLT 2022 data preparation.
    \b
    This is conversational telephone speech collected as 8kHz-sampled data.
    The catalog number LDC2022E01 corresponds to the train, dev, and test1
    splits of the IWSLT 2022 shared task.
    To obtain this data, your institution needs an LDC subscription.
    You should also download the predefined splits with:
    git clone https://github.com/kevinduh/iwslt22-dialect.git
    """
    langs_list = langs.split(",")
    prepare_iwslt22_ta(
        corpus_dir,
        splits,
        output_dir=output_dir,
        num_jobs=num_jobs,
        clean=normalize_text,
        langs=langs_list,
    )
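
As a usage sketch (not part of the commit), the same preparation can be invoked from Python just as the CLI does above. The paths and language codes below are placeholders, not values from this repository:

    from lhotse.recipes.iwslt22_ta import prepare_iwslt22_ta

    # Manifests are written to output_dir; arguments mirror the CLI call above.
    prepare_iwslt22_ta(
        "data/LDC2022E01",  # corpus_dir: placeholder path to the unpacked LDC release
        "iwslt22-dialect",  # splits: placeholder path to the cloned splits repository
        output_dir="data/manifests",
        num_jobs=4,
        clean=True,  # same effect as passing --normalize-text on the CLI
        langs=["ta", "eng"],  # assumed abbreviations; the CLI splits --langs on commas
    )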
@@ -0,0 +1,203 @@
# Copyright 2023 Johns Hopkins (authors: Amir Hussein)

from typing import Callable, Dict, List, Optional, Union

import torch
from torch.utils.data.dataloader import default_collate

from lhotse.cut import CutSet
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.dataset.speech_recognition import validate_for_asr
from lhotse.utils import compute_num_frames, ifnone
from lhotse.workarounds import Hdf5MemoryIssueFix


class K2Speech2TextTranslationDataset(torch.utils.data.Dataset):
    """
    The PyTorch Dataset for the speech translation task using the k2 library.
    This dataset expects to be queried with lists of cut IDs,
    for which it loads features and automatically collates/batches them.
    To use it with a PyTorch DataLoader, set ``batch_size=None``
    and provide a :class:`SimpleCutSampler` sampler.
    Each item in this dataset is a dict of:
    .. code-block::
        {
            'inputs': float tensor with shape determined by :attr:`input_strategy`:
                      - single-channel:
                        - features: (B, T, F)
                        - audio: (B, T)
                      - multi-channel: currently not supported
            'supervisions': [
                {
                    'sequence_idx': Tensor[int] of shape (S,)
                    'text': List[str] of len S  (source transcript)
                    'tgt_text': List[str] of len S  (target translation)
                    # For feature input strategies
                    'start_frame': Tensor[int] of shape (S,)
                    'num_frames': Tensor[int] of shape (S,)
                    # For audio input strategies
                    'start_sample': Tensor[int] of shape (S,)
                    'num_samples': Tensor[int] of shape (S,)
                    # Optionally, when return_cuts=True
                    'cut': List[AnyCut] of len S
                }
            ]
        }
    Dimension symbols legend:
    * ``B`` - batch size (number of Cuts)
    * ``S`` - number of supervision segments (greater than or equal to B, as each Cut may have multiple supervisions)
    * ``T`` - number of frames of the longest Cut
    * ``F`` - number of features
    The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset.
    """

    def __init__(
        self,
        return_cuts: bool = False,
        cut_transforms: Optional[List[Callable[[CutSet], CutSet]]] = None,
        input_transforms: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None,
        input_strategy: BatchIO = PrecomputedFeatures(),
    ):
        """
        K2Speech2TextTranslationDataset constructor.
        :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
            objects used to create that batch.
        :param cut_transforms: A list of transforms to be applied on each sampled batch,
            before converting cuts to an input representation (audio/features).
            Examples: cut concatenation, noise cuts mixing, etc.
        :param input_transforms: A list of transforms to be applied on each sampled batch,
            after the cuts are converted to audio/features.
            Examples: normalization, SpecAugment, etc.
        :param input_strategy: Converts cuts into a collated batch of audio/features.
            By default, reads pre-computed features from disk.
        """
        super().__init__()
        # Initialize the fields
        self.return_cuts = return_cuts
        self.cut_transforms = ifnone(cut_transforms, [])
        self.input_transforms = ifnone(input_transforms, [])
        self.input_strategy = input_strategy

        # This attribute is a workaround to constantly growing HDF5 memory
        # throughout the epoch. It regularly closes open file handles to
        # reset the internal HDF5 caches.
        self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100)

    def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
        """
        Return a new batch, with the batch size automatically determined using the constraints
        of max_frames and max_cuts.
        """
        validate_for_asr(cuts)
        self.hdf5_fix.update()

        # Sort the cuts by duration so that the first one determines the batch time dimensions.
        cuts = cuts.sort_by_duration(ascending=False)

        # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
        # the supervision boundaries.
        for tnfm in self.cut_transforms:
            cuts = tnfm(cuts)

        # Sort the cuts again after transforms
        cuts = cuts.sort_by_duration(ascending=False)

        # Get a tensor with batched feature matrices, shape (B, T, F)
        # Collation performs auto-padding, if necessary.
        input_tpl = self.input_strategy(cuts)
        if len(input_tpl) == 3:
            # An input strategy with fault tolerant audio reading mode.
            # "cuts" may be a subset of the original "cuts" variable,
            # that only has cuts for which we successfully read the audio.
            inputs, _, cuts = input_tpl
        else:
            inputs, _ = input_tpl

        # Get a dict of tensors that encode the positional information about supervisions
        # in the batch of feature matrices. The tensors are named "sequence_idx",
        # "start_frame/sample" and "num_frames/samples".
        supervision_intervals = self.input_strategy.supervision_intervals(cuts)

        # Apply all available transforms on the inputs, i.e. either audio or features.
        # This could be feature extraction, global MVN, SpecAugment, etc.
        segments = torch.stack(list(supervision_intervals.values()), dim=1)
        for tnfm in self.input_transforms:
            inputs = tnfm(inputs, supervision_segments=segments)
        batch = {
            "inputs": inputs,
            "supervisions": default_collate(
                [
                    {
                        "text": supervision.text,
                        "tgt_text": supervision.custom["translated_text"],
                    }
                    for sequence_idx, cut in enumerate(cuts)
                    for supervision in cut.supervisions
                ]
            ),
        }
        # Update the 'supervisions' field with sequence_idx and start/num frames/samples
        batch["supervisions"].update(supervision_intervals)
        if self.return_cuts:
            batch["supervisions"]["cut"] = [
                cut for cut in cuts for sup in cut.supervisions
            ]

        has_word_alignments = all(
            s.alignment is not None and "word" in s.alignment
            for c in cuts
            for s in c.supervisions
        )
        if has_word_alignments:
            # TODO: might need to refactor BatchIO API to move the following conditional logic
            # into these objects (e.g. use like: self.input_strategy.convert_timestamp(),
            # that returns either num_frames or num_samples depending on the strategy).
            words, starts, ends = [], [], []
            frame_shift = cuts[0].frame_shift
            sampling_rate = cuts[0].sampling_rate
            if frame_shift is None:
                try:
                    frame_shift = self.input_strategy.extractor.frame_shift
                except AttributeError:
                    raise ValueError(
                        "Can't determine the frame_shift -- it is not present either in cuts or the input_strategy."
                    )
            for c in cuts:
                for s in c.supervisions:
                    words.append([aliword.symbol for aliword in s.alignment["word"]])
                    starts.append(
                        [
                            compute_num_frames(
                                aliword.start,
                                frame_shift=frame_shift,
                                sampling_rate=sampling_rate,
                            )
                            for aliword in s.alignment["word"]
                        ]
                    )
                    ends.append(
                        [
                            compute_num_frames(
                                aliword.end,
                                frame_shift=frame_shift,
                                sampling_rate=sampling_rate,
                            )
                            for aliword in s.alignment["word"]
                        ]
                    )
            batch["supervisions"]["word"] = words
            batch["supervisions"]["word_start"] = starts
            batch["supervisions"]["word_end"] = ends

        return batch
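
As a usage sketch for the new dataset (not part of the commit), the pattern from the docstring, ``batch_size=None`` plus a :class:`SimpleCutSampler`, looks roughly like this. The manifest path is a placeholder, ``max_duration`` is an arbitrary example value, and each supervision is assumed to carry ``custom["translated_text"]`` as the collation code above expects:

    from torch.utils.data import DataLoader

    from lhotse import CutSet
    from lhotse.dataset import SimpleCutSampler

    cuts = CutSet.from_file("data/manifests/cuts_train.jsonl.gz")  # placeholder path
    dataset = K2Speech2TextTranslationDataset(return_cuts=True)
    sampler = SimpleCutSampler(cuts, max_duration=200.0)  # cap each batch by total seconds of audio
    dloader = DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=2)

    for batch in dloader:
        feats = batch["inputs"]  # (B, T, F) padded feature matrices
        src = batch["supervisions"]["text"]  # source transcripts
        tgt = batch["supervisions"]["tgt_text"]  # target translations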