def greedy(src_mask: Tensor, max_output_length: int, model: Model,
-           encoder_output: Tensor, encoder_hidden: Tensor) \
-        -> Tuple[np.array, np.array]:
+           encoder_output: Tensor, encoder_hidden: Tensor,
+           generate_unk: bool = False) -> Tuple[np.array, np.array]:
    """
    Greedy decoding. Select the token with the highest probability at each time
    step. This function is a wrapper that calls recurrent_greedy for
@@ -31,7 +31,11 @@ def greedy(src_mask: Tensor, max_output_length: int, model: Model,
    :param model: model to use for greedy decoding
    :param encoder_output: encoder hidden states for attention
    :param encoder_hidden: encoder last state for decoder initialization
+    :param generate_unk: whether to generate the UNK token. If false,
+        the probability of the UNK token will artificially be set to zero.
    :return:
+        - stacked_output: output hypotheses (2d array of indices),
+        - stacked_attention_scores: attention scores (3d array)
    """
    # pylint: disable=no-else-return
    if isinstance(model.decoder, TransformerDecoder):
@@ -47,7 +51,8 @@ def greedy(src_mask: Tensor, max_output_length: int, model: Model,


def recurrent_greedy(src_mask: Tensor, max_output_length: int, model: Model,
-                     encoder_output: Tensor, encoder_hidden: Tensor) \
+                     encoder_output: Tensor, encoder_hidden: Tensor,
+                     generate_unk: bool = False) \
        -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Greedy decoding: in each step, choose the word that gets the highest score.
@@ -58,12 +63,15 @@ def recurrent_greedy(src_mask: Tensor, max_output_length: int, model: Model,
    :param model: model to use for greedy decoding
    :param encoder_output: encoder hidden states for attention
    :param encoder_hidden: encoder last state for decoder initialization
+    :param generate_unk: whether to generate the UNK token. If false,
+        the probability of the UNK token will artificially be set to zero.
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (3d array)
    """
    bos_index = model.bos_index
    eos_index = model.eos_index
+    unk_index = model.unk_index
    batch_size = src_mask.size(0)
    prev_y = src_mask.new_full(size=[batch_size, 1], fill_value=bos_index,
                               dtype=torch.long)
@@ -88,6 +96,8 @@ def recurrent_greedy(src_mask: Tensor, max_output_length: int, model: Model,
        # logits: batch x time=1 x vocab (logits)

        # greedy decoding: choose arg max over vocabulary in each step
+        if not generate_unk:
+            logits[:, :, unk_index] = float("-inf")
        next_word = torch.argmax(logits, dim=-1)  # batch x time=1
        output.append(next_word.squeeze(1).detach().cpu().numpy())
        prev_y = next_word
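As an aside on the masking added in this hunk (not part of the patch itself): forcing a logit to -inf is enough to keep a token out of both the argmax used for greedy decoding and the softmax probability mass. A minimal stand-alone sketch with a toy vocabulary and an assumed unk_index of 0:

import torch
import torch.nn.functional as F

logits = torch.tensor([[[2.5, 1.0, 0.3, 2.4]]])  # batch x time=1 x vocab
unk_index = 0  # assumed here for illustration; joeynmt takes it from model.unk_index

masked = logits.clone()
masked[:, :, unk_index] = float("-inf")

print(torch.argmax(logits, dim=-1))  # tensor([[0]]): unmasked argmax would emit UNK
print(torch.argmax(masked, dim=-1))  # tensor([[3]]): masked argmax picks the next-best token
print(F.softmax(masked, dim=-1)[0, 0, unk_index])  # tensor(0.): UNK gets zero probability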
@@ -107,7 +117,8 @@ def recurrent_greedy(src_mask: Tensor, max_output_length: int, model: Model,


def transformer_greedy(src_mask: Tensor, max_output_length: int, model: Model,
-                       encoder_output: Tensor, encoder_hidden: Tensor) \
+                       encoder_output: Tensor, encoder_hidden: Tensor,
+                       generate_unk: bool = False) \
        -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Special greedy function for transformer, since it works differently.
@@ -118,13 +129,16 @@ def transformer_greedy(src_mask: Tensor, max_output_length: int, model: Model,
    :param model: model to use for greedy decoding
    :param encoder_output: encoder hidden states for attention
    :param encoder_hidden: encoder final state (unused in Transformer)
+    :param generate_unk: whether to generate the UNK token. If false,
+        the probability of the UNK token will artificially be set to zero.
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (3d array)
    """
    # pylint: disable=unused-argument
    bos_index = model.bos_index
    eos_index = model.eos_index
+    unk_index = model.unk_index
    batch_size = src_mask.size(0)

    # start with BOS-symbol for each sentence in the batch
@@ -152,6 +166,8 @@ def transformer_greedy(src_mask: Tensor, max_output_length: int, model: Model,
                trg_mask=trg_mask
            )
            logits = nll_logits[:, -1]
+            if not generate_unk:
+                logits[:, unk_index] = float("-inf")
            _, next_word = torch.max(logits, dim=1)
            next_word = next_word.data
            ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1)
@@ -169,8 +185,8 @@ def transformer_greedy(src_mask: Tensor, max_output_length: int, model: Model,

def beam_search(model: Model, size: int, encoder_output: Tensor,
                encoder_hidden: Tensor, src_mask: Tensor,
-                max_output_length: int, alpha: float, n_best: int = 1) \
-        -> Tuple[np.ndarray, Optional[np.ndarray]]:
+                max_output_length: int, alpha: float, n_best: int = 1,
+                generate_unk=False) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Beam search with size k.
    Inspired by OpenNMT-py, adapted for Transformer.
@@ -183,6 +199,8 @@ def beam_search(model: Model, size: int, encoder_output: Tensor,
    :param max_output_length:
    :param alpha: `alpha` factor for length penalty
    :param n_best: return this many hypotheses, <= beam (currently only 1)
+    :param generate_unk: whether to generate the UNK token. If false,
+        the probability of the UNK token will artificially be set to zero.
    :return:
        - stacked_output: output hypotheses (2d array of indices),
        - stacked_attention_scores: attention scores (3d array)
@@ -195,6 +213,7 @@ def beam_search(model: Model, size: int, encoder_output: Tensor,
    bos_index = model.bos_index
    eos_index = model.eos_index
    pad_index = model.pad_index
+    unk_index = model.unk_index
    trg_vocab_size = model.decoder.output_size
    device = encoder_output.device
    transformer = isinstance(model.decoder, TransformerDecoder)
@@ -316,6 +335,8 @@ def beam_search(model: Model, size: int, encoder_output: Tensor,

        # batch*k x trg_vocab
        log_probs = F.log_softmax(logits, dim=-1).squeeze(1)
+        if not generate_unk:
+            log_probs[:, unk_index] = float("-inf")

        # multiply probs by the beam probability (=add logprobs)
        log_probs += topk_log_probs.view(-1).unsqueeze(1)
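One detail worth noting for review: the greedy functions mask the raw logits, whereas this hunk masks the log-probabilities after log_softmax, so the remaining entries are not renormalized. The ranking of non-UNK tokens within a step is identical either way; only the absolute hypothesis scores differ slightly. A small sketch of the two variants (toy values, not taken from this patch):

import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 3.0, 2.0]])  # toy vocab: [UNK, a, b]
unk_index = 0

# Variant used in beam_search above: mask after log_softmax (no renormalization)
log_probs = F.log_softmax(logits, dim=-1)
log_probs[:, unk_index] = float("-inf")

# Alternative: mask the logits first, then normalize (UNK's mass is redistributed)
masked_logits = logits.clone()
masked_logits[:, unk_index] = float("-inf")
renormalized = F.log_softmax(masked_logits, dim=-1)

print(log_probs)     # UNK is -inf; 'a' and 'b' keep their original scores
print(renormalized)  # UNK is -inf; 'a' and 'b' come out slightly higher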
@@ -439,7 +460,8 @@ def pad_and_stack_hyps(hyps, pad_value):


def run_batch(model: Model, batch: Batch, max_output_length: int,
-              beam_size: int, beam_alpha: float, n_best: int = 1) \
+              beam_size: int, beam_alpha: float, n_best: int = 1,
+              generate_unk: bool = False) \
        -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """
    Get outputs and attention scores for a given batch
@@ -475,7 +497,8 @@ def run_batch(model: Model, batch: Batch, max_output_length: int,
            max_output_length=max_output_length,
            model=model,
            encoder_output=encoder_output,
-            encoder_hidden=encoder_hidden)
+            encoder_hidden=encoder_hidden,
+            generate_unk=generate_unk)
        # batch, time, max_src_length
    else:  # beam search
        stacked_output, stacked_attention_scores = beam_search(
@@ -486,6 +509,7 @@ def run_batch(model: Model, batch: Batch, max_output_length: int,
            src_mask=src_mask,
            max_output_length=max_output_length,
            alpha=beam_alpha,
-            n_best=n_best)
+            n_best=n_best,
+            generate_unk=generate_unk)

    return stacked_output, stacked_attention_scores
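For reviewers, this is roughly how the new keyword reaches the decoders: callers of run_batch pass generate_unk, and it is forwarded to greedy or beam_search depending on the beam size. A call sketch only, not runnable on its own, since `model` and `batch` are built by the surrounding joeynmt prediction code (not part of this diff) and the numeric values are illustrative:

# `model` and `batch` are assumed to exist; values below are examples only.
stacked_output, stacked_attention_scores = run_batch(
    model=model,
    batch=batch,
    max_output_length=100,
    beam_size=5,          # illustrative; a larger beam takes the beam_search branch
    beam_alpha=1.0,
    n_best=1,
    generate_unk=False)   # default: UNK is suppressed in the hypotheses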