From 3211e72408a5bbc08508595ee12f59f1e95ac61d Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Wed, 11 Sep 2024 04:08:34 -0700
Subject: [PATCH] If the seq2seq doesn't predict any spaces in the MWT, use the original word to avoid it going crazy

---
 stanza/models/mwt/trainer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/stanza/models/mwt/trainer.py b/stanza/models/mwt/trainer.py
index 37fddd1846..58f01ec10f 100644
--- a/stanza/models/mwt/trainer.py
+++ b/stanza/models/mwt/trainer.py
@@ -114,7 +114,10 @@ def predict(self, batch, unsort=True, never_decode_unk=False, vocab=None):
         # if any tokens are predicted to expand to blank,
         # that is likely an error. use the original text
         # this originally came up with the Spanish model turning 's' into a blank
-        pred_tokens = [x if x else y for x, y in zip(pred_tokens, orig_text)]
+        # furthermore, if there are no spaces predicted by the seq2seq,
+        # might as well use the original in case the seq2seq went crazy
+        # this particular error came up training a Hebrew MWT
+        pred_tokens = [x if x and ' ' in x else y for x, y in zip(pred_tokens, orig_text)]
         if unsort:
             pred_tokens = utils.unsort(pred_tokens, orig_idx)
         return pred_tokens