
Commit 2e65c81: force not to decode unk token
1 parent 5748520 · commit 2e65c81

10 files changed, +161 -6095 lines
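The heart of the change: at decoding time, the UNK token is suppressed by setting its score to negative infinity before the next token is selected, so it can never win. A minimal sketch of the mechanism outside joeynmt's classes (the shapes and the unk_index value are illustrative assumptions):

import torch

logits = torch.randn(2, 6)  # toy batch of 2, vocab of 6
unk_index = 1               # illustrative; joeynmt reads it off the vocabulary

generate_unk = False
if not generate_unk:
    # -inf becomes probability zero under (log_)softmax, so neither
    # greedy argmax nor beam-search topk can ever pick <unk>
    logits[:, unk_index] = float("-inf")

next_word = torch.argmax(logits, dim=-1)
assert (next_word != unk_index).all()

The same masking is applied in three places below: recurrent greedy, transformer greedy, and beam search.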

joeynmt/helpers_for_audio.py

Lines changed: 5 additions & 14 deletions
@@ -4,7 +4,6 @@
 """
 
 import io
-import os
 from pathlib import Path
 import sys
 from typing import List, Optional, Tuple, Union
@@ -24,7 +23,7 @@
                if unicodedata.category(chr(i)).startswith('P')}
 def remove_punc(sent: str) -> str:
     """Remove punctuation based on Unicode category.
-    Note: punctuations in audio transcription are often removed.
+    Note: punctuations in audio transcription are sometimes removed.
 
     :param sent: sentence string
     """
@@ -36,7 +35,7 @@ def __init__(self, fbank_path: str, n_frames: int, idx: Union[int, str]):
         """Speech Instance
 
         :param fbank_path: (str) Feature file path in the format either of
-            "<zip path>:<byte offset>:<byte length>" or "<file name>.mp3"
+            "<zip path>:<byte offset>:<byte length>" or "<file name>.{mp3|wav}"
         :param n_frames: (int) number of frames
         :param idx: index
         """
@@ -69,17 +68,11 @@ def _get_torchaudio_fbank(waveform: torch.FloatTensor, sample_rate: int,
 # from fairseq
 def extract_fbank_features(waveform: torch.FloatTensor,
                            sample_rate: int,
-                           n_frames: int,
-                           utt_id: str,
-                           feature_root: Optional[Path] = None,
+                           output_path: Optional[Path] = None,
                            n_mel_bins: int = 80,
                            overwrite: bool = False) -> Optional[np.ndarray]:
     # pylint: disable=inconsistent-return-statements
 
-    output_path = None
-    if feature_root is not None:
-        output_path = feature_root / f"{utt_id}.npy"
-
     if output_path is not None and output_path.is_file() and not overwrite:
         return
 
@@ -88,10 +81,9 @@ def extract_fbank_features(waveform: torch.FloatTensor,
 
     try:
         features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
-        assert abs(features.shape[0] - n_frames) <= 1, (n_frames, features.shape)
     except Exception as e:
         raise ValueError(f"torchaudio failed to extract mel filterbank features "
-                         f"at utt_id: {utt_id}. {e}")
+                         f"at: {output_path.stem}. {e}")
 
     if output_path is not None:
         np.save(output_path.as_posix(), features)
@@ -137,8 +129,7 @@ def get_features(root_path: Path, fbank_path: str) -> np.ndarray:
             features = np.load(_path.as_posix())
         elif _path.suffix in [".mp3", ".wav"]:
             waveform, sample_rate = torchaudio.load(_path.as_posix())
-            num_frames = get_n_frames(waveform.size(1), sample_rate)
-            features = extract_fbank_features(waveform, sample_rate, num_frames)
+            features = extract_fbank_features(waveform, sample_rate)
         else:
             raise ValueError(f"Invalid file type: {_path}")
     elif len(extra) == 2:
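For reference, the _get_torchaudio_fbank helper that extract_fbank_features wraps follows the fairseq recipe around torchaudio.compliance.kaldi.fbank. A hedged sketch of that call (the int16 scaling mirrors fairseq's convention; the file name is hypothetical, and the exact preprocessing is an assumption, not read from this diff):

import torchaudio
from torchaudio.compliance.kaldi import fbank

waveform, sample_rate = torchaudio.load("utt0001.wav")  # hypothetical file
_waveform = waveform * (2 ** 15)  # kaldi-style fbank expects int16-range values
features = fbank(_waveform, num_mel_bins=80,
                 sample_frequency=sample_rate)  # tensor of shape (n_frames, 80)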

joeynmt/model.py

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ def __init__(self,
         self.pad_index = self.trg_vocab.pad_index
         self.bos_index = self.trg_vocab.bos_index
         self.eos_index = self.trg_vocab.eos_index
+        self.unk_index = self.trg_vocab.unk_index
         self._loss_function = None  # set by the TrainManager
 
     @property

joeynmt/prediction.py

Lines changed: 7 additions & 5 deletions
@@ -119,8 +119,8 @@ def validate_on_data(model: Model,
                     return_type="loss", **vars(batch))
                 if n_gpu > 1:
                     batch_loss = batch_loss.sum()  # sum on multi-gpu
-                    nll_loss = nll_loss.sum()
-                    ctc_loss = ctc_loss.sum()
+                    nll_loss = nll_loss.sum() if torch.is_tensor(nll_loss) else None
+                    ctc_loss = ctc_loss.sum() if torch.is_tensor(ctc_loss) else None
                     n_correct = n_correct.float().sum()
                 total_loss['loss'] += batch_loss.item()  # float
                 if torch.is_tensor(nll_loss):  # nll_loss is not None
@@ -135,7 +135,7 @@ def validate_on_data(model: Model,
             output, attention_scores = run_batch(
                 model=model, batch=batch, beam_size=beam_size,
                 beam_alpha=beam_alpha, max_output_length=max_output_length,
-                n_best=n_best)
+                n_best=n_best, generate_unk=False)
 
             # sort outputs back to original order
             all_outputs.extend(output[sort_reverse_index])
@@ -146,8 +146,10 @@ def validate_on_data(model: Model,
     if compute_loss and total_ntokens > 0:
         total_normalizer = 1 if total_normalizer == 0 else total_normalizer
         valid_scores['loss'] = total_loss['loss'] / total_normalizer
-        valid_scores['nll_loss'] = total_loss['nll_loss'] / total_normalizer
-        valid_scores['ctc_loss'] = total_loss['ctc_loss'] / total_normalizer
+        if 'nll_loss' in total_loss:
+            valid_scores['nll_loss'] = total_loss['nll_loss'] / total_normalizer
+        if 'ctc_loss' in total_loss:
+            valid_scores['ctc_loss'] = total_loss['ctc_loss'] / total_normalizer
         # accuracy before decoding
         valid_scores['acc'] = total_n_correct / total_ntokens
         # exponent of token-level negative log prob
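The validation changes make both the multi-GPU reduction and the final normalization tolerate an absent loss term (e.g. a model trained without the CTC branch). A small sketch of the pattern with stand-in values (none of these numbers come from the commit):

import torch

nll_loss, ctc_loss = torch.tensor([1.2, 0.8]), None  # ctc branch absent here

# reduce only what is actually a tensor
nll_loss = nll_loss.sum() if torch.is_tensor(nll_loss) else None
ctc_loss = ctc_loss.sum() if torch.is_tensor(ctc_loss) else None

total_loss = {'loss': 4.0}
if torch.is_tensor(nll_loss):
    total_loss['nll_loss'] = total_loss.get('nll_loss', 0.0) + nll_loss.item()

# normalize only the keys that were ever accumulated
total_normalizer = 100
valid_scores = {k: v / total_normalizer for k, v in total_loss.items()}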

joeynmt/search.py

Lines changed: 33 additions & 9 deletions
@@ -19,8 +19,8 @@
 
 
 def greedy(src_mask: Tensor, max_output_length: int, model: Model,
-           encoder_output: Tensor, encoder_hidden: Tensor) \
-        -> Tuple[np.array, np.array]:
+           encoder_output: Tensor, encoder_hidden: Tensor,
+           generate_unk: bool = False) -> Tuple[np.array, np.array]:
     """
     Greedy decoding. Select the token with the highest probability at each time
     step. This function is a wrapper that calls recurrent_greedy for
@@ -31,7 +31,11 @@ def greedy(src_mask: Tensor, max_output_length: int, model: Model,
     :param model: model to use for greedy decoding
     :param encoder_output: encoder hidden states for attention
     :param encoder_hidden: encoder last state for decoder initialization
+    :param generate_unk: whether to generate the UNK token; if false,
+        the probability of the UNK token will artificially be set to zero
     :return:
+        - stacked_output: output hypotheses (2d array of indices),
+        - stacked_attention_scores: attention scores (3d array)
     """
     # pylint: disable=no-else-return
     if isinstance(model.decoder, TransformerDecoder):
@@ -47,7 +51,8 @@ def greedy(src_mask: Tensor, max_output_length: int, model: Model,
 
 
 def recurrent_greedy(src_mask: Tensor, max_output_length: int, model: Model,
-                     encoder_output: Tensor, encoder_hidden: Tensor) \
+                     encoder_output: Tensor, encoder_hidden: Tensor,
+                     generate_unk: bool = False) \
         -> Tuple[np.ndarray, Optional[np.ndarray]]:
     """
     Greedy decoding: in each step, choose the word that gets highest score.
@@ -58,12 +63,15 @@ def recurrent_greedy(src_mask: Tensor, max_output_length: int, model: Model,
     :param model: model to use for greedy decoding
     :param encoder_output: encoder hidden states for attention
     :param encoder_hidden: encoder last state for decoder initialization
+    :param generate_unk: whether to generate the UNK token; if false,
+        the probability of the UNK token will artificially be set to zero
     :return:
         - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (3d array)
     """
     bos_index = model.bos_index
     eos_index = model.eos_index
+    unk_index = model.unk_index
     batch_size = src_mask.size(0)
     prev_y = src_mask.new_full(size=[batch_size, 1], fill_value=bos_index,
                                dtype=torch.long)
@@ -88,6 +96,8 @@ def recurrent_greedy(src_mask: Tensor, max_output_length: int, model: Model,
         # logits: batch x time=1 x vocab (logits)
 
         # greedy decoding: choose arg max over vocabulary in each step
+        if not generate_unk:
+            logits[:, :, unk_index] = float("-inf")
         next_word = torch.argmax(logits, dim=-1)  # batch x time=1
         output.append(next_word.squeeze(1).detach().cpu().numpy())
         prev_y = next_word
@@ -107,7 +117,8 @@ def recurrent_greedy(src_mask: Tensor, max_output_length: int, model: Model,
 
 
 def transformer_greedy(src_mask: Tensor, max_output_length: int, model: Model,
-                       encoder_output: Tensor, encoder_hidden: Tensor) \
+                       encoder_output: Tensor, encoder_hidden: Tensor,
+                       generate_unk: bool = False) \
         -> Tuple[np.ndarray, Optional[np.ndarray]]:
     """
     Special greedy function for transformer, since it works differently.
@@ -118,13 +129,16 @@ def transformer_greedy(src_mask: Tensor, max_output_length: int, model: Model,
     :param model: model to use for greedy decoding
     :param encoder_output: encoder hidden states for attention
     :param encoder_hidden: encoder final state (unused in Transformer)
+    :param generate_unk: whether to generate the UNK token; if false,
+        the probability of the UNK token will artificially be set to zero
     :return:
         - stacked_output: output hypotheses (2d array of indices),
         - stacked_attention_scores: attention scores (3d array)
     """
     # pylint: disable=unused-argument
     bos_index = model.bos_index
     eos_index = model.eos_index
+    unk_index = model.unk_index
     batch_size = src_mask.size(0)
 
     # start with BOS-symbol for each sentence in the batch
@@ -152,6 +166,8 @@ def transformer_greedy(src_mask: Tensor, max_output_length: int, model: Model,
             trg_mask=trg_mask
         )
         logits = nll_logits[:, -1]
+        if not generate_unk:
+            logits[:, unk_index] = float("-inf")
         _, next_word = torch.max(logits, dim=1)
         next_word = next_word.data
         ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1)
@@ -169,8 +185,8 @@ def transformer_greedy(src_mask: Tensor, max_output_length: int, model: Model,
 
 def beam_search(model: Model, size: int, encoder_output: Tensor,
                 encoder_hidden: Tensor, src_mask: Tensor,
-                max_output_length: int, alpha: float, n_best: int = 1) \
-        -> Tuple[np.ndarray, Optional[np.ndarray]]:
+                max_output_length: int, alpha: float, n_best: int = 1,
+                generate_unk: bool = False) -> Tuple[np.ndarray, Optional[np.ndarray]]:
     """
     Beam search with size k.
     Inspired by OpenNMT-py, adapted for Transformer.
@@ -183,6 +199,8 @@ def beam_search(model: Model, size: int, encoder_output: Tensor,
     :param max_output_length:
     :param alpha: `alpha` factor for length penalty
     :param n_best: return this many hypotheses, <= beam (currently only 1)
+    :param generate_unk: whether to generate the UNK token; if false,
+        the probability of the UNK token will artificially be set to zero
     :return:
         - stacked_output: output hypotheses (2d array of indices),
         - stacked_attention_scores: attention scores (3d array)
@@ -195,6 +213,7 @@ def beam_search(model: Model, size: int, encoder_output: Tensor,
     bos_index = model.bos_index
     eos_index = model.eos_index
     pad_index = model.pad_index
+    unk_index = model.unk_index
     trg_vocab_size = model.decoder.output_size
     device = encoder_output.device
     transformer = isinstance(model.decoder, TransformerDecoder)
@@ -316,6 +335,8 @@ def beam_search(model: Model, size: int, encoder_output: Tensor,
 
         # batch*k x trg_vocab
         log_probs = F.log_softmax(logits, dim=-1).squeeze(1)
+        if not generate_unk:
+            log_probs[:, unk_index] = float("-inf")
 
         # multiply probs by the beam probability (=add logprobs)
         log_probs += topk_log_probs.view(-1).unsqueeze(1)
@@ -439,7 +460,8 @@ def pad_and_stack_hyps(hyps, pad_value):
 
 
 def run_batch(model: Model, batch: Batch, max_output_length: int,
-              beam_size: int, beam_alpha: float, n_best: int = 1) \
+              beam_size: int, beam_alpha: float, n_best: int = 1,
+              generate_unk: bool = False) \
         -> Tuple[np.ndarray, Optional[np.ndarray]]:
     """
     Get outputs and attentions scores for a given batch
@@ -475,7 +497,8 @@ def run_batch(model: Model, batch: Batch, max_output_length: int,
             max_output_length=max_output_length,
             model=model,
             encoder_output=encoder_output,
-            encoder_hidden=encoder_hidden)
+            encoder_hidden=encoder_hidden,
+            generate_unk=generate_unk)
         # batch, time, max_src_length
     else:  # beam search
         stacked_output, stacked_attention_scores = beam_search(
@@ -486,6 +509,7 @@ def run_batch(model: Model, batch: Batch, max_output_length: int,
             src_mask=src_mask,
             max_output_length=max_output_length,
             alpha=beam_alpha,
-            n_best=n_best)
+            n_best=n_best,
+            generate_unk=generate_unk)
 
     return stacked_output, stacked_attention_scores
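Note that beam_search masks in log-probability space rather than on raw logits; since adding the accumulated beam scores leaves -inf at -inf, <unk> can never enter the top-k. A toy illustration of that step (beam size and vocab size are invented):

import torch
import torch.nn.functional as F

k, vocab_size, unk_index = 2, 5, 1  # invented sizes; index 1 plays <unk>
log_probs = F.log_softmax(torch.randn(k, vocab_size), dim=-1)
log_probs[:, unk_index] = float("-inf")  # mask before beam accumulation

topk_log_probs = torch.zeros(k, 1)  # running beam scores
scores = (log_probs + topk_log_probs).view(-1)  # -inf + finite = -inf
best = scores.topk(k)
assert all(i.item() % vocab_size != unk_index for i in best.indices)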

joeynmt/training.py

Lines changed: 7 additions & 1 deletion
@@ -819,10 +819,16 @@ def train(cfg_file: str, skip_test: bool = False) -> None:
     src_vocab, trg_vocab, train_data, dev_data, test_data = load_data(
         data_cfg=cfg["data"])
 
-    # store the vocabs
+    # store the vocabs and tokenizers
     if task == "MT":
         src_vocab.to_file(model_dir / "src_vocab.txt")
+        if "model_file" in cfg["data"]["src"]["spm"]:
+            src_tok = Path(cfg["data"]["src"]["spm"]["model_file"])
+            shutil.copy2(src_tok, (model_dir / src_tok.name).as_posix())
     trg_vocab.to_file(model_dir / "trg_vocab.txt")
+    if "model_file" in cfg["data"]["trg"]["spm"]:
+        trg_tok = Path(cfg["data"]["trg"]["spm"]["model_file"])
+        shutil.copy2(trg_tok, (model_dir / trg_tok.name).as_posix())
 
     # build an encoder-decoder model
     model = build_model(cfg["model"], src_vocab=src_vocab, trg_vocab=trg_vocab)
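The copy step assumes each side of the data config nests a SentencePiece model under spm.model_file. A hedged sketch of the layout those lookups expect and the resulting copy (all paths are placeholders, not from the commit):

from pathlib import Path
import shutil

cfg = {"data": {"src": {"spm": {"model_file": "data/spm/src.model"}},
                "trg": {"spm": {"model_file": "data/spm/trg.model"}}}}
model_dir = Path("models/my_run")  # hypothetical run directory

for side in ("src", "trg"):
    spm_cfg = cfg["data"][side].get("spm", {})
    if "model_file" in spm_cfg:
        tok = Path(spm_cfg["model_file"])
        if tok.is_file():
            # keep the tokenizer next to the vocab so the run dir is self-contained
            shutil.copy2(tok, (model_dir / tok.name).as_posix())

Using .get("spm", {}) here also keeps the step from raising a KeyError when a config has no spm block at all; the committed code indexes the key directly.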

joeynmt/vocabulary.py

Lines changed: 2 additions & 0 deletions
@@ -44,9 +44,11 @@ def __init__(self, tokens: List[str]) -> None:
         self.pad_index = self.lookup(PAD_TOKEN)
         self.bos_index = self.lookup(BOS_TOKEN)
         self.eos_index = self.lookup(EOS_TOKEN)
+        self.unk_index = self.lookup(UNK_TOKEN)
         assert self.pad_index == PAD_ID
         assert self.bos_index == BOS_ID
         assert self.eos_index == EOS_ID
+        assert self.unk_index == UNK_ID
         assert self._itos[UNK_ID] == UNK_TOKEN
 
     def __str__(self) -> str:
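The new assertions pin <unk> to a fixed slot alongside the other special tokens, which is what lets the decoder trust model.unk_index. A toy sketch of that contract (the token strings and ID ordering are assumptions mirroring common joeynmt conventions, not read from this diff):

UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN = "<unk>", "<pad>", "<s>", "</s>"
UNK_ID, PAD_ID, BOS_ID, EOS_ID = 0, 1, 2, 3  # assumed fixed ordering

itos = [UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, "hello", "world"]
stoi = {tok: i for i, tok in enumerate(itos)}

def lookup(token: str) -> int:
    return stoi.get(token, UNK_ID)  # unknown strings fall back to <unk>

assert lookup(UNK_TOKEN) == UNK_ID
assert lookup("never-seen") == UNK_ID  # why a stable unk_index matters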

scripts/audiodata_utils.py

Lines changed: 2 additions & 9 deletions
@@ -27,27 +27,20 @@ def get_zip_manifest(zip_path: Path, npy_root: Optional[Path] = None):
     manifest = {}
     with zipfile.ZipFile(zip_path, mode="r") as f:
         info = f.infolist()
-    error_flag = []
+    # retrieve offsets
     for i in tqdm(info):
         utt_id = Path(i.filename).stem
         offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
         with zip_path.open("rb") as f:
             f.seek(offset)
             data = f.read(file_size)
-        try:
-            assert len(data) > 1 and _is_npy_data(data), (utt_id, len(data), e)
-        except Exception as e:
-            print((utt_id, len(data), e))
-            error_flag.append((utt_id, len(data)))
+        assert len(data) > 1 and _is_npy_data(data), (utt_id, len(data))
         manifest[utt_id] = f"{zip_path.name}:{offset}:{file_size}"
         # sanity check
         if npy_root is not None:
             byte_data = np.load(io.BytesIO(data))
             npy_data = np.load((npy_root / f"{utt_id}.npy").as_posix())
             assert np.allclose(byte_data, npy_data)
-    if len(error_flag) > 0:
-        print(error_flag)
-        raise Exception
     return manifest
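The manifest entries ("<zip name>:<offset>:<length>") let a reader pull one utterance's .npy bytes straight out of the zip; the header_offset + 30 + len(filename) arithmetic assumes a stored (uncompressed) entry with an empty extra field. A sketch of the consuming side (file name and numbers are hypothetical):

import io
import numpy as np

def load_from_manifest(entry: str) -> np.ndarray:
    zip_name, offset, length = entry.rsplit(":", 2)
    with open(zip_name, "rb") as f:
        f.seek(int(offset))
        data = f.read(int(length))
    # the slice is a complete .npy payload, so np.load parses it directly
    return np.load(io.BytesIO(data))

# features = load_from_manifest("fbank.zip:12345:67890")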
