diff --git a/examples/simultaneous_translation/README.md b/examples/simultaneous_translation/README.md
index bbc6dacdda..62a005e0ec 100644
--- a/examples/simultaneous_translation/README.md
+++ b/examples/simultaneous_translation/README.md
@@ -1,106 +1,5 @@
-# Simultaneous Machine Translation
-
-This directory contains the code for the paper [Monotonic Multihead Attention](https://openreview.net/forum?id=Hyg96gBKPS)
-
-## Prepare Data
-
-[Please follow the instructions to download and preprocess the WMT'15 En-De dataset.](https://github.com/pytorch/fairseq/tree/simulastsharedtask/examples/translation#prepare-wmt14en2desh)
-
-## Training
-
-- MMA-IL
-
-```shell
-fairseq-train \
-    data-bin/wmt15_en_de_32k \
-    --simul-type infinite_lookback \
-    --user-dir $FAIRSEQ/example/simultaneous_translation \
-    --mass-preservation \
-    --criterion latency_augmented_label_smoothed_cross_entropy \
-    --latency-weight-avg 0.1 \
-    --max-update 50000 \
-    --arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \
-    --optimizer adam --adam-betas '(0.9, 0.98)' \
-    --lr-scheduler 'inverse_sqrt' \
-    --warmup-init-lr 1e-7 --warmup-updates 4000 \
-    --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\
-    --dropout 0.3 \
-    --label-smoothing 0.1\
-    --max-tokens 3584
-```
-
-- MMA-H
-
-```shell
-fairseq-train \
-    data-bin/wmt15_en_de_32k \
-    --simul-type hard_aligned \
-    --user-dir $FAIRSEQ/example/simultaneous_translation \
-    --mass-preservation \
-    --criterion latency_augmented_label_smoothed_cross_entropy \
-    --latency-weight-var 0.1 \
-    --max-update 50000 \
-    --arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \
-    --optimizer adam --adam-betas '(0.9, 0.98)' \
-    --lr-scheduler 'inverse_sqrt' \
-    --warmup-init-lr 1e-7 --warmup-updates 4000 \
-    --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\
-    --dropout 0.3 \
-    --label-smoothing 0.1\
-    --max-tokens 3584
-```
-
-- wait-k
-
-```shell
-fairseq-train \
-    data-bin/wmt15_en_de_32k \
-    --simul-type wait-k \
-    --waitk-lagging 3 \
-    --user-dir $FAIRSEQ/example/simultaneous_translation \
-    --mass-preservation \
-    --criterion latency_augmented_label_smoothed_cross_entropy \
-    --max-update 50000 \
-    --arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \
-    --optimizer adam --adam-betas '(0.9, 0.98)' \
-    --lr-scheduler 'inverse_sqrt' \
-    --warmup-init-lr 1e-7 --warmup-updates 4000 \
-    --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\
-    --dropout 0.3 \
-    --label-smoothing 0.1\
-    --max-tokens 3584
-```
-
-
-## Evaluation
-
-More details on evaluation can be found [here](https://github.com/pytorch/fairseq/blob/simulastsharedtask/examples/simultaneous_translation/docs/evaluation.md)
-
-### Start the server
-
-```shell
-python ./eval/server.py \
-    --src-file $SRC_FILE \
-    --ref-file $TGT_FILE
-```
-
-### Run the client
-
-```shell
-python ./evaluate.py \
-    --data-bin data-bin/wmt15_en_de_32k \
-    --model-path ./checkpoints/checkpoint_best.pt
-    --scores --output $RESULT_DIR
-```
-
-### Run evaluation locally without server
-
-```shell
-python ./eval/evaluate.py
-    --local \
-    --src-file $SRC_FILE \
-    --tgt-file $TGT_FILE \
-    --data-bin data-bin/wmt15_en_de_32k \
-    --model-path ./checkpoints/checkpoint_best.pt \
-    --scores --output $RESULT_DIR
-```
+# Simultaneous Translation
+Examples of simultaneous translation in fairseq
+- [English-to-Japanese text-to-text wait-k model](docs/enja-waitk.md)
+- [English-to-German text-to-text monotonic multihead attention model](docs/ende-mma.md)
+- [English-to-German speech-to-text simultaneous translation model](../speech_to_text/docs/simulst_mustc_example.md)
diff --git a/examples/simultaneous_translation/docs/ende-mma.md b/examples/simultaneous_translation/docs/ende-mma.md
new file mode 100644
index 0000000000..241d604a3b
--- /dev/null
+++ b/examples/simultaneous_translation/docs/ende-mma.md
@@ -0,0 +1,74 @@
+# Simultaneous Machine Translation
+
+This directory contains the code for the paper [Monotonic Multihead Attention](https://openreview.net/forum?id=Hyg96gBKPS)
+
+## Prepare Data
+
+[Please follow the instructions to download and preprocess the WMT'15 En-De dataset.](https://github.com/pytorch/fairseq/tree/simulastsharedtask/examples/translation#prepare-wmt14en2desh)
+
+Another example of training an English-to-Japanese model can be found [here](enja-waitk.md)
+
+## Training
+
+- MMA-IL
+
+```shell
+fairseq-train \
+    data-bin/wmt15_en_de_32k \
+    --simul-type infinite_lookback \
+    --user-dir $FAIRSEQ/examples/simultaneous_translation \
+    --mass-preservation \
+    --criterion latency_augmented_label_smoothed_cross_entropy \
+    --latency-weight-avg 0.1 \
+    --max-update 50000 \
+    --arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \
+    --optimizer adam --adam-betas '(0.9, 0.98)' \
+    --lr-scheduler 'inverse_sqrt' \
+    --warmup-init-lr 1e-7 --warmup-updates 4000 \
+    --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001 \
+    --dropout 0.3 \
+    --label-smoothing 0.1 \
+    --max-tokens 3584
+```
+
+- MMA-H
+
+```shell
+fairseq-train \
+    data-bin/wmt15_en_de_32k \
+    --simul-type hard_aligned \
+    --user-dir $FAIRSEQ/examples/simultaneous_translation \
+    --mass-preservation \
+    --criterion latency_augmented_label_smoothed_cross_entropy \
+    --latency-weight-var 0.1 \
+    --max-update 50000 \
+    --arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \
+    --optimizer adam --adam-betas '(0.9, 0.98)' \
+    --lr-scheduler 'inverse_sqrt' \
+    --warmup-init-lr 1e-7 --warmup-updates 4000 \
+    --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001 \
+    --dropout 0.3 \
+    --label-smoothing 0.1 \
+    --max-tokens 3584
+```
+
+- wait-k
+
+```shell
+fairseq-train \
+    data-bin/wmt15_en_de_32k \
+    --simul-type wait-k \
+    --waitk-lagging 3 \
+    --user-dir $FAIRSEQ/examples/simultaneous_translation \
+    --mass-preservation \
+    --criterion latency_augmented_label_smoothed_cross_entropy \
+    --max-update 50000 \
+    --arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \
+    --optimizer adam --adam-betas '(0.9, 0.98)' \
+    --lr-scheduler 'inverse_sqrt' \
+    --warmup-init-lr 1e-7 --warmup-updates 4000 \
+    --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001 \
+    --dropout 0.3 \
+    --label-smoothing 0.1 \
+    --max-tokens 3584
+```
diff --git a/examples/simultaneous_translation/docs/enja-waitk.md b/examples/simultaneous_translation/docs/enja-waitk.md
new file mode 100644
index 0000000000..fb9d82576f
--- /dev/null
+++ b/examples/simultaneous_translation/docs/enja-waitk.md
@@ -0,0 +1,106 @@
+# An Example of an English-to-Japanese Simultaneous Translation System
+
+This is an example of training and evaluating a transformer *wait-k* English-to-Japanese simultaneous text-to-text translation model.
+
+## Data Preparation
+This section introduces the data preparation for training and evaluation.
+If you only want to evaluate the model, please jump to [Inference & Evaluation](#inference--evaluation)
+
+For illustration, we only use the following subsets of the available data from the [WMT20 news translation task](http://www.statmt.org/wmt20/translation-task.html), which results in 7,815,391 sentence pairs.
+- News Commentary v16
+- Wiki Titles v3
+- WikiMatrix V1
+- Japanese-English Subtitle Corpus
+- The Kyoto Free Translation Task Corpus
+
+We use the WMT20 development data as the development set. Training a `transformer_vaswani_wmt_en_de_big` model on this amount of data yields 17.3 BLEU with greedy search and 19.7 with beam search (beam size 10). Note that better performance can be achieved with the full WMT training data.
+
+We use the [sentencepiece](https://github.com/google/sentencepiece) toolkit to tokenize the data with a vocabulary size of 32000.
+Additionally, we filter out sentences longer than 200 words after tokenization.
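+As a rough sketch, the tokenization step might look as follows. `spm_train` and `spm_encode` ship with sentencepiece; the raw-text paths under `${RAW_DIR}`, the `enja.spm` model prefix, and the choice of a joint source-target model are hypothetical — only the 32000 vocabulary size comes from this document:
+
+```bash
+# Learn a joint sentencepiece model on the raw parallel text.
+spm_train \
+    --input=${RAW_DIR}/train.en,${RAW_DIR}/train.ja \
+    --model_prefix=enja.spm \
+    --vocab_size=32000 \
+    --character_coverage=0.9995
+
+# Apply it to each side of each split, e.g. the training source side.
+spm_encode --model=enja.spm.model --output_format=piece \
+    < ${RAW_DIR}/train.en > ${DATA_DIR}/train.en
+```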
+Assuming the tokenized text data is saved at `${DATA_DIR}`,
+we prepare the data binary with the following command.
+
+```bash
+fairseq-preprocess \
+    --source-lang en --target-lang ja \
+    --trainpref ${DATA_DIR}/train \
+    --validpref ${DATA_DIR}/dev \
+    --testpref ${DATA_DIR}/test \
+    --destdir ${WMT20_ENJA_DATA_BIN} \
+    --nwordstgt 32000 --nwordssrc 32000 \
+    --workers 20
+```
+
+## Simultaneous Translation Model Training
+To train a wait-k (`k=10`) model:
+```bash
+fairseq-train ${WMT20_ENJA_DATA_BIN} \
+    --save-dir ${SAVE_DIR} \
+    --simul-type waitk \
+    --waitk-lagging 10 \
+    --max-epoch 70 \
+    --arch transformer_monotonic_vaswani_wmt_en_de_big \
+    --optimizer adam \
+    --adam-betas '(0.9, 0.98)' \
+    --lr-scheduler inverse_sqrt \
+    --warmup-init-lr 1e-07 \
+    --warmup-updates 4000 \
+    --lr 0.0005 \
+    --stop-min-lr 1e-09 \
+    --clip-norm 10.0 \
+    --dropout 0.3 \
+    --weight-decay 0.0 \
+    --criterion label_smoothed_cross_entropy \
+    --label-smoothing 0.1 \
+    --max-tokens 3584
+```
+This command is for training on 8 GPUs. Equivalently, the model can be trained on a single GPU with `--update-freq 8`.
+
+## Inference & Evaluation
+First, install [SimulEval](https://github.com/facebookresearch/SimulEval) for evaluation.
+
+```bash
+git clone https://github.com/facebookresearch/SimulEval.git
+cd SimulEval
+pip install -e .
+```
+
+The following command runs the evaluation,
+assuming the source and reference files are `${SRC_FILE}` and `${TGT_FILE}`, and the sentencepiece model file for English is saved at `${SRC_SPM_PATH}`.
+
+```bash
+simuleval \
+    --source ${SRC_FILE} \
+    --target ${TGT_FILE} \
+    --data-bin ${WMT20_ENJA_DATA_BIN} \
+    --sacrebleu-tokenizer ja-mecab \
+    --eval-latency-unit char \
+    --no-space \
+    --src-splitter-type sentencepiecemodel \
+    --src-splitter-path ${SRC_SPM_PATH} \
+    --agent ${FAIRSEQ}/examples/simultaneous_translation/eval/agents/simul_t2t_enja.py \
+    --model-path ${SAVE_DIR}/${CHECKPOINT_FILENAME} \
+    --output ${OUTPUT} \
+    --scores
+```
+
+The `--data-bin` should be the same as in the previous sections if you prepared the data from scratch.
+If you only want to run the evaluation, a prepared data directory can be found [here](https://dl.fbaipublicfiles.com/simultaneous_translation/wmt20_enja_medium_databin.tgz) and a pretrained checkpoint (wait-k=10 model) can be downloaded from [here](https://dl.fbaipublicfiles.com/simultaneous_translation/wmt20_enja_medium_wait10_ckpt.pt).
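+For example, the prepared files might be fetched and unpacked as follows (a minimal sketch, assuming `wget` and `tar` are available; the file names simply follow the URLs above):
+
+```bash
+wget https://dl.fbaipublicfiles.com/simultaneous_translation/wmt20_enja_medium_databin.tgz
+tar -xzf wmt20_enja_medium_databin.tgz
+wget https://dl.fbaipublicfiles.com/simultaneous_translation/wmt20_enja_medium_wait10_ckpt.pt
+```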
+
+The output should look like this:
+```bash
+{
+    "Quality": {
+        "BLEU": 11.442253287568398
+    },
+    "Latency": {
+        "AL": 8.6587861866951,
+        "AP": 0.7863304776251316,
+        "DAL": 9.477850951194764
+    }
+}
+```
+Latency is measured in characters on the target side (`--eval-latency-unit char`). Quality is evaluated with `sacrebleu` using the `MeCab` tokenizer (`--sacrebleu-tokenizer ja-mecab`). `--no-space` indicates that no space is added when merging the predicted words.
+
+If the `--output ${OUTPUT}` option is used, the detailed log and scores will be stored under the `${OUTPUT}` directory.
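+For instance, a quick way to inspect a finished run (the `scores` and `instances.log` file names reflect SimulEval's usual output layout and should be treated as an assumption of this example):
+
+```bash
+cat ${OUTPUT}/scores               # the JSON summary shown above
+head -n 1 ${OUTPUT}/instances.log  # per-sentence hypotheses and delays
+```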
diff --git a/examples/simultaneous_translation/eval/__init__.py b/examples/simultaneous_translation/eval/__init__.py
deleted file mode 100644
index 6264236915..0000000000
--- a/examples/simultaneous_translation/eval/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
diff --git a/examples/simultaneous_translation/eval/agents/__init__.py b/examples/simultaneous_translation/eval/agents/__init__.py
deleted file mode 100644
index 511e7b2474..0000000000
--- a/examples/simultaneous_translation/eval/agents/__init__.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import importlib
-import os
-
-from fairseq import registry
-
-
-build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry(
-    "--agent-type"
-)
-
-
-DEFAULT_EOS = "</s>"
-GET = 0
-SEND = 1
-
-for file in os.listdir(os.path.dirname(__file__)):
-    if file.endswith(".py") and not file.startswith("_"):
-        module = file[: file.find(".py")]
-        importlib.import_module("agents." + module)
diff --git a/examples/simultaneous_translation/eval/agents/agent.py b/examples/simultaneous_translation/eval/agents/agent.py
deleted file mode 100644
index 997392cf9b..0000000000
--- a/examples/simultaneous_translation/eval/agents/agent.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import time
-from functools import partial
-from multiprocessing.pool import ThreadPool as Pool
-
-from . import DEFAULT_EOS, GET, SEND
-
-
-class Agent(object):
-    "an agent needs to follow this pattern"
-
-    def __init__(self, *args, **kwargs):
-        pass
-
-    def init_states(self, *args, **kwargs):
-        raise NotImplementedError
-
-    def update_states(self, states, new_state):
-        raise NotImplementedError
-
-    def finish_eval(self, states, new_state):
-        raise NotImplementedError
-
-    def policy(self, state):
-        raise NotImplementedError
-
-    def reset(self):
-        raise NotImplementedError
-
-    def decode(self, session, low=0, high=100000, num_thread=10):
-        corpus_info = session.corpus_info()
-        high = min(corpus_info["num_sentences"] - 1, high)
-        if low >= high:
-            return
-
-        t0 = time.time()
-        if num_thread > 1:
-            with Pool(10) as p:
-                p.map(
-                    partial(self._decode_one, session),
-                    [sent_id for sent_id in range(low, high + 1)],
-                )
-        else:
-            for sent_id in range(low, high + 1):
-                self._decode_one(session, sent_id)
-
-        print(f"Finished {low} to {high} in {time.time() - t0}s")
-
-    def _decode_one(self, session, sent_id):
-        action = {}
-        self.reset()
-        states = self.init_states()
-        while action.get("value", None) != DEFAULT_EOS:
-            # take an action
-            action = self.policy(states)
-
-            if action["key"] == GET:
-                new_states = session.get_src(sent_id, action["value"])
-                states = self.update_states(states, new_states)
-
-            elif action["key"] == SEND:
-                session.send_hypo(sent_id, action["value"])
-        print(" ".join(states["tokens"]["tgt"]))
diff --git a/examples/simultaneous_translation/eval/agents/simul_t2t_enja.py b/examples/simultaneous_translation/eval/agents/simul_t2t_enja.py
new file mode 100644
index 0000000000..8f3c8703ca
--- /dev/null
+++ b/examples/simultaneous_translation/eval/agents/simul_t2t_enja.py
@@ -0,0 +1,226 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+from fairseq import checkpoint_utils, tasks
+import sentencepiece as spm
+import torch
+
+try:
+    from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS
+    from simuleval.agents import TextAgent
+except ImportError:
+    print("Please install simuleval: 'pip install simuleval'")
+
+
+BOS_PREFIX = "\u2581"
+
+
+class SimulTransTextAgentJA(TextAgent):
+    """
+    Simultaneous translation text agent for Japanese.
+    """
+    def __init__(self, args):
+
+        # Whether to use GPU
+        self.gpu = getattr(args, "gpu", False)
+
+        # Maximum length of the translation
+        self.max_len = args.max_len
+
+        # Load model and vocabulary
+        self.load_model_vocab(args)
+
+        # Build word splitter
+        self.build_word_splitter(args)
+
+        self.eos = DEFAULT_EOS
+
+    def initialize_states(self, states):
+        states.incremental_states = dict()
+        states.incremental_states["online"] = dict()
+
+    def to_device(self, tensor):
+        if self.gpu:
+            return tensor.cuda()
+        else:
+            return tensor.cpu()
+
+    def load_model_vocab(self, args):
+
+        filename = args.model_path
+        if not os.path.exists(filename):
+            raise IOError("Model file not found: {}".format(filename))
+
+        state = checkpoint_utils.load_checkpoint_to_cpu(filename)
+
+        task_args = state["cfg"]["task"]
+        task_args.data = args.data_bin
+
+        task = tasks.setup_task(task_args)
+
+        # build model for ensemble
+        state["cfg"]["model"].load_pretrained_encoder_from = None
+        state["cfg"]["model"].load_pretrained_decoder_from = None
+
+        self.model = task.build_model(state["cfg"]["model"])
+        self.model.load_state_dict(state["model"], strict=True)
+        self.model.eval()
+        self.model.share_memory()
+
+        if self.gpu:
+            self.model.cuda()
+
+        # Set dictionaries
+        self.dict = {}
+        self.dict["tgt"] = task.target_dictionary
+        self.dict["src"] = task.source_dictionary
+
+    @staticmethod
+    def add_args(parser):
+        # fmt: off
+        parser.add_argument('--model-path', type=str, required=True,
+                            help='path to your pretrained model.')
+        parser.add_argument("--data-bin", type=str, required=True,
+                            help="Path of data binary")
+        parser.add_argument("--max-len", type=int, default=100,
+                            help="Max length of translation")
+        parser.add_argument("--tgt-splitter-type", type=str, default="SentencePiece",
+                            help="Subword splitter type for target text.")
+        parser.add_argument("--tgt-splitter-path", type=str, default=None,
+                            help="Subword splitter model path for target text.")
+        parser.add_argument("--src-splitter-type", type=str, default="SentencePiece",
+                            help="Subword splitter type for source text.")
+        parser.add_argument("--src-splitter-path", type=str, default=None,
+                            help="Subword splitter model path for source text.")
+        # fmt: on
+        return parser
+
+    def build_word_splitter(self, args):
+        self.spm = {}
+        for lang in ['src', 'tgt']:
+            if getattr(args, f'{lang}_splitter_type', None):
+                path = getattr(args, f'{lang}_splitter_path', None)
+                if path:
+                    self.spm[lang] = spm.SentencePieceProcessor()
+                    self.spm[lang].Load(path)
+
+    def segment_to_units(self, segment, states):
+        # Split a full word (segment) into subwords (units)
+        return self.spm['src'].EncodeAsPieces(segment)
+
+    def update_model_encoder(self, states):
+        if len(states.units.source) == 0:
+            return
+
+        src_indices = [
+            self.dict['src'].index(x)
+            for x in states.units.source.value
+        ]
+
+        if states.finish_read():
+            # Append the eos index once reading is finished
+            src_indices += [self.dict["tgt"].eos_index]
+
+        src_indices = self.to_device(
+            torch.LongTensor(src_indices).unsqueeze(0)
+        )
+        src_lengths = self.to_device(
+            torch.LongTensor([src_indices.size(1)])
+        )
+
+        states.encoder_states = self.model.encoder(src_indices, src_lengths)
+
+        torch.cuda.empty_cache()
+
+    def update_states_read(self, states):
+        # Happens after a read action.
+        self.update_model_encoder(states)
+
+    def units_to_segment(self, units, states):
+        # Merge subwords (units) back into a full word (segment).
+        # For Japanese, we can directly send the untokenized token to the
+        # server, except for the BOS token, given the following options:
+        #     --sacrebleu-tokenizer MeCab
+        #     --eval-latency-unit char
+        #     --no-space
+        token = units.value.pop()
+
+        if (
+            token == self.dict["tgt"].eos_word
+            or len(states.segments.target) > self.max_len
+        ):
+            return DEFAULT_EOS
+
+        if BOS_PREFIX == token:
+            return None
+        if token[0] == BOS_PREFIX:
+            return token[1:]
+        else:
+            return token
+
+    def policy(self, states):
+
+        if not getattr(states, "encoder_states", None):
+            # No encoder states yet, read a token first
+            return READ_ACTION
+
+        # encode previous predicted target tokens
+        tgt_indices = self.to_device(
+            torch.LongTensor(
+                [self.model.decoder.dictionary.eos()]
+                + [
+                    self.dict['tgt'].index(x)
+                    for x in states.units.target.value
+                    if x is not None
+                ]
+            ).unsqueeze(0)
+        )
+
+        # Current steps
+        states.incremental_states["steps"] = {
+            "src": states.encoder_states["encoder_out"][0].size(0),
+            "tgt": 1 + len(states.units.target),
+        }
+
+        # "Online" only means the reading is not finished
+        states.incremental_states["online"]["only"] = (
+            torch.BoolTensor([not states.finish_read()])
+        )
+
+        x, outputs = self.model.decoder.forward(
+            prev_output_tokens=tgt_indices,
+            encoder_out=states.encoder_states,
+            incremental_state=states.incremental_states,
+        )
+
+        states.decoder_out = x
+
+        torch.cuda.empty_cache()
+
+        if outputs.action == 0:
+            return READ_ACTION
+        else:
+            return WRITE_ACTION
+
+    def predict(self, states):
+        # Predict the next target token from the decoder states
+        decoder_states = states.decoder_out
+
+        lprobs = self.model.get_normalized_probs(
+            [decoder_states[:, -1:]], log_probs=True
+        )
+
+        index = lprobs.argmax(dim=-1)[0, 0].item()
+
+        if index != self.dict['tgt'].eos_index:
+            token = self.dict['tgt'].string([index])
+        else:
+            token = self.dict['tgt'].eos_word
+
+        return token
diff --git a/examples/simultaneous_translation/eval/agents/simul_trans_agent.py b/examples/simultaneous_translation/eval/agents/simul_trans_agent.py
deleted file mode 100644
index 071b9e89ce..0000000000
--- a/examples/simultaneous_translation/eval/agents/simul_trans_agent.py
+++ /dev/null
@@ -1,167 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import json
-import os
-
-from fairseq import checkpoint_utils, tasks, utils
-
-from . import DEFAULT_EOS, GET, SEND
-from .agent import Agent
-
-
-class SimulTransAgent(Agent):
-    def __init__(self, args):
-        # Load Model
-        self.load_model(args)
-
-        # build word spliter
-        self.build_word_splitter(args)
-
-        self.max_len = args.max_len
-
-        self.eos = DEFAULT_EOS
-
-    @staticmethod
-    def add_args(parser):
-        # fmt: off
-        parser.add_argument('--model-path', type=str, required=True,
-                            help='path to your pretrained model.')
-        parser.add_argument("--data-bin", type=str, required=True,
-                            help="Path of data binary")
-        parser.add_argument("--user-dir", type=str, default="example/simultaneous_translation",
-                            help="User directory for simultaneous translation")
-        parser.add_argument("--src-splitter-type", type=str, default=None,
-                            help="Subword splitter type for source text")
-        parser.add_argument("--tgt-splitter-type", type=str, default=None,
-                            help="Subword splitter type for target text")
-        parser.add_argument("--src-splitter-path", type=str, default=None,
-                            help="Subword splitter model path for source text")
-        parser.add_argument("--tgt-splitter-path", type=str, default=None,
-                            help="Subword splitter model path for target text")
-        parser.add_argument("--max-len", type=int, default=150,
-                            help="Maximum length difference between source and target prediction")
-        parser.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
-                            help='A dictionary used to override model args at generation '
-                                 'that were used during model training')
-        # fmt: on
-        return parser
-
-    def load_dictionary(self, task):
-        raise NotImplementedError
-
-    def load_model(self, args):
-        args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..")
-        utils.import_user_module(args)
-        filename = args.model_path
-        if not os.path.exists(filename):
-            raise IOError("Model file not found: {}".format(filename))
-
-        state = checkpoint_utils.load_checkpoint_to_cpu(
-            filename, json.loads(args.model_overrides)
-        )
-
-        saved_args = state["args"]
-        saved_args.data = args.data_bin
-
-        task = tasks.setup_task(saved_args)
-
-        # build model for ensemble
-        self.model = task.build_model(saved_args)
-        self.model.load_state_dict(state["model"], strict=True)
-
-        # Set dictionary
-        self.load_dictionary(task)
-
-    def init_states(self):
-        return {
-            "indices": {"src": [], "tgt": []},
-            "tokens": {"src": [], "tgt": []},
-            "segments": {"src": [], "tgt": []},
-            "steps": {"src": 0, "tgt": 0},
-            "finished": False,
-            "finish_read": False,
-            "model_states": {},
-        }
-
-    def update_states(self, states, new_state):
-        raise NotImplementedError
-
-    def policy(self, states):
-        # Read and Write policy
-        action = None
-
-        while action is None:
-            if states["finished"]:
-                # Finish the hypo by sending eos to server
-                return self.finish_action()
-
-            # Model make decision given current states
-            decision = self.model.decision_from_states(states)
-
-            if decision == 0 and not self.finish_read(states):
-                # READ
-                action = self.read_action(states)
-            else:
-                # WRITE
-                action = self.write_action(states)
-
-            # None means we make decision again but not sending server anything
-            # This happened when read a bufffered token
-            # Or predict a subword
-        return action
-
-    def finish_read(self, states):
-        raise NotImplementedError
-
-    def write_action(self, states):
-        token, index = self.model.predict_from_states(states)
-
-        if (
-            index == self.dict["tgt"].eos()
-            or len(states["tokens"]["tgt"]) > self.max_len
-        ):
-            # Finish this sentence is predict EOS
-            states["finished"] = True
-            end_idx_last_full_word = self._target_length(states)
-
-        else:
states["tokens"]["tgt"] += [token] - end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word( - states["tokens"]["tgt"] - ) - self._append_indices(states, [index], "tgt") - - if end_idx_last_full_word > states["steps"]["tgt"]: - # Only sent detokenized full words to the server - word = self.word_splitter["tgt"].merge( - states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word] - ) - states["steps"]["tgt"] = end_idx_last_full_word - states["segments"]["tgt"] += [word] - - return {"key": SEND, "value": word} - else: - return None - - def read_action(self, states): - return {"key": GET, "value": None} - - def finish_action(self): - return {"key": SEND, "value": DEFAULT_EOS} - - def reset(self): - pass - - def finish_eval(self, states, new_state): - if len(new_state) == 0 and len(states["indices"]["src"]) == 0: - return True - return False - - def _append_indices(self, states, new_indices, key): - states["indices"][key] += new_indices - - def _target_length(self, states): - return len(states["tokens"]["tgt"]) diff --git a/examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py b/examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py deleted file mode 100644 index 7c34817bf6..0000000000 --- a/examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from . import DEFAULT_EOS, GET, register_agent -from .simul_trans_agent import SimulTransAgent -from .word_splitter import SPLITTER_DICT - - -@register_agent("simul_trans_text") -class SimulTransTextAgent(SimulTransAgent): - def build_word_splitter(self, args): - self.word_splitter = {} - - self.word_splitter["src"] = SPLITTER_DICT[args.src_splitter_type]( - getattr(args, f"src_splitter_path") - ) - self.word_splitter["tgt"] = SPLITTER_DICT[args.tgt_splitter_type]( - getattr(args, f"tgt_splitter_path") - ) - - def load_dictionary(self, task): - self.dict = {} - self.dict["tgt"] = task.target_dictionary - self.dict["src"] = task.source_dictionary - - def update_states(self, states, new_state): - if states["finish_read"]: - return states - - new_word = new_state["segment"] - - # Split words and index the token - if new_word not in [DEFAULT_EOS]: - tokens = self.word_splitter["src"].split(new_word) - # Get indices from dictionary - # You can change to you own dictionary - indices = ( - self.dict["src"] - .encode_line( - tokens, - line_tokenizer=lambda x: x, - add_if_not_exist=False, - append_eos=False, - ) - .tolist() - ) - else: - tokens = [new_word] - indices = [self.dict["src"].eos()] - states["finish_read"] = True - - # Update states - states["segments"]["src"] += [new_word] - states["tokens"]["src"] += tokens - self._append_indices(states, indices, "src") - - return states - - def read_action(self, states): - # Increase source step by one - states["steps"]["src"] += 1 - - # At leat one word is read - if len(states["tokens"]["src"]) == 0: - return {"key": GET, "value": None} - - # Only request new word if there is no buffered tokens - if len(states["tokens"]["src"]) <= states["steps"]["src"]: - return {"key": GET, "value": None} - - return None - - def finish_read(self, states): - # The first means all segments (full words) has been read from server - # The second means all tokens (subwords) has been read locally - return ( - states["finish_read"] - and 
len(states["tokens"]["src"]) == states["steps"]["src"] - ) diff --git a/examples/simultaneous_translation/eval/agents/word_splitter.py b/examples/simultaneous_translation/eval/agents/word_splitter.py deleted file mode 100644 index c3f71200a5..0000000000 --- a/examples/simultaneous_translation/eval/agents/word_splitter.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - - -class SubwordSplitter(object): - def process_line(self, string): - raise NotImplementedError - - def split(self, string): - raise NotImplementedError - - -class NoneWordSplitter(object): - def __init__(self, model): - pass - - def split(self, string): - return [string] - - def process_line(self, string): - return [string] - - def finished_word(self, string): - return True - - def merge(self, list_of_string): - return "".join(list_of_string) - - def last_full_word_step(self, tokens, step): - return len(tokens) - - def end_idx_last_full_word(self, tokens): - return len(tokens) - - -class BPEWordSplitter(object): - # TODO: lock back here - def __init__(self, model_path): - super().__init__() - from subword_nmt.apply_bpe import BPE - - with open(model_path) as f: - self.model = BPE(f) - - def split(self, string): - return self.model.process_line(string).split() - - def end_idx_last_full_word(self, tokens): - # Begin of word indices - bow_indices = [0] + [i + 1 for i, t in enumerate(tokens[1:]) if t[-2:] != "@@"] - - if len(bow_indices) < 2: - return 0 - else: - return bow_indices[-1] - - def merge(self, list_of_string): - return " ".join([item.replace("@@", "") for item in list_of_string]) - - -class SentencePieceModelWordSplitter(object): - def __init__(self, model_path): - super().__init__() - import sentencepiece as spm - - self.model = spm.SentencePieceProcessor() - self.model.Load(model_path) - - def split(self, string): - return self.model.EncodeAsPieces(string) - - def end_idx_last_full_word(self, tokens): - # Begin of word indices - bow_indices = [i for i, t in enumerate(tokens) if t[0] == "\u2581"] - - if len(bow_indices) < 2: - return 0 - else: - return bow_indices[-1] - - def merge(self, list_of_string): - return self.model.DecodePieces(list_of_string) - - -SPLITTER_DICT = { - None: NoneWordSplitter, - "BPE": BPEWordSplitter, - "SentencePieceModel": SentencePieceModelWordSplitter, -} diff --git a/examples/simultaneous_translation/eval/client.py b/examples/simultaneous_translation/eval/client.py deleted file mode 100644 index 3ca4ea73b8..0000000000 --- a/examples/simultaneous_translation/eval/client.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
-
-from typing import Optional
-
-import requests
-from scorers import build_scorer
-
-
-class SimulSTEvaluationService(object):
-    DEFAULT_HOSTNAME = "localhost"
-    DEFAULT_PORT = 12321
-
-    def __init__(self, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT):
-        self.hostname = hostname
-        self.port = port
-        self.base_url = f"http://{self.hostname}:{self.port}"
-
-    def __enter__(self):
-        self.new_session()
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        pass
-
-    def new_session(self):
-        # start eval session
-        url = f"{self.base_url}"
-
-        try:
-            _ = requests.post(url)
-        except Exception as e:
-            print(f"Failed to start an evaluation session: {e}")
-
-        print("Evaluation session started.")
-        return self
-
-    def get_scores(self):
-        # end eval session
-        url = f"{self.base_url}/result"
-        try:
-            r = requests.get(url)
-            print("Scores: {}".format(r.json()))
-            print("Evaluation session finished.")
-        except Exception as e:
-            print(f"Failed to end an evaluation session: {e}")
-
-    def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str:
-        url = f"{self.base_url}/src"
-        params = {"sent_id": sent_id}
-        if extra_params is not None:
-            for key in extra_params.keys():
-                params[key] = extra_params[key]
-        try:
-            r = requests.get(url, params=params)
-        except Exception as e:
-            print(f"Failed to request a source segment: {e}")
-        return r.json()
-
-    def send_hypo(self, sent_id: int, hypo: str) -> None:
-        url = f"{self.base_url}/hypo"
-        params = {"sent_id": sent_id}
-
-        try:
-            requests.put(url, params=params, data=hypo.encode("utf-8"))
-        except Exception as e:
-            print(f"Failed to send a translated segment: {e}")
-
-    def corpus_info(self):
-        url = f"{self.base_url}"
-        try:
-            r = requests.get(url)
-        except Exception as e:
-            print(f"Failed to request corpus information: {e}")
-
-        return r.json()
-
-
-class SimulSTLocalEvaluationService(object):
-    def __init__(self, args):
-        self.scorer = build_scorer(args)
-
-    def get_scores(self):
-        return self.scorer.score()
-
-    def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str:
-        if extra_params is not None:
-            segment_size = extra_params.get("segment_size", None)
-        else:
-            segment_size = None
-
-        return self.scorer.send_src(int(sent_id), segment_size)
-
-    def send_hypo(self, sent_id: int, hypo: str) -> None:
-        list_of_tokens = hypo.strip().split()
-        self.scorer.recv_hyp(sent_id, list_of_tokens)
-
-    def corpus_info(self):
-        return self.scorer.get_info()
diff --git a/examples/simultaneous_translation/eval/eval_latency.py b/examples/simultaneous_translation/eval/eval_latency.py
deleted file mode 100644
index 50021de47c..0000000000
--- a/examples/simultaneous_translation/eval/eval_latency.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import argparse
-import json
-
-import torch
-from examples.simultaneous_translation.utils.latency import LatencyInference
-
-
-LATENCY_METRICS = [
-    "differentiable_average_lagging",
-    "average_lagging",
-    "average_proportion",
-]
-
-
-class LatencyScorer:
-    def __init__(self, start_from_zero=True):
-        self.recorder = []
-        self.scores = {}
-        self.scorer = LatencyInference()
-        self.start_from_zero = start_from_zero
-
-    def update_reorder(self, list_of_dict):
-        self.recorder = []
-        for info in list_of_dict:
-            delays = [int(x) - int(not self.start_from_zero) for x in info["delays"]]
-            delays = torch.LongTensor(delays).unsqueeze(0)
-            src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0)
-
-            self.recorder.append(self.scorer(delays, src_len))
-
-    def cal_latency(self):
-        self.scores = {}
-        for metric in LATENCY_METRICS:
-            self.scores[metric] = sum(
-                [x[metric][0, 0].item() for x in self.recorder]
-            ) / len(self.recorder)
-        return self.scores
-
-    @classmethod
-    def score(cls, list_of_dict, start_from_zero=True):
-        scorer_to_return = cls(start_from_zero)
-        scorer_to_return.update_reorder(list_of_dict)
-        scorer_to_return.cal_latency()
-        return scorer_to_return.scores
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--input", required=True)
-    parser.add_argument("--start-from-zero", action="store_true")
-    args = parser.parse_args()
-
-    scorer = LatencyInference()
-    recorder = []
-    with open(args.input, "r") as f:
-        for line in f:
-            info = json.loads(line)
-
-            delays = [int(x) - int(not args.start_from_zero) for x in info["delays"]]
-
-            delays = torch.LongTensor(delays).unsqueeze(0)
-
-            src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0)
-
-            recorder.append(scorer(delays, src_len))
-
-    average_results = {}
-
-    for metric in LATENCY_METRICS:
-        average_results[metric] = sum([x[metric][0, 0].item() for x in recorder]) / len(
-            recorder
-        )
-        print(f"{metric}: {average_results[metric]}")
diff --git a/examples/simultaneous_translation/eval/evaluate.py b/examples/simultaneous_translation/eval/evaluate.py
deleted file mode 100644
index 2f7474621a..0000000000
--- a/examples/simultaneous_translation/eval/evaluate.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import argparse
-
-from agents import build_agent
-from client import SimulSTEvaluationService, SimulSTLocalEvaluationService
-from fairseq.registry import REGISTRIES
-
-
-DEFAULT_HOSTNAME = "localhost"
-DEFAULT_PORT = 12321
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-        "--hostname", type=str, default=DEFAULT_HOSTNAME, help="server hostname"
-    )
-    parser.add_argument(
-        "--port", type=int, default=DEFAULT_PORT, help="server port number"
-    )
-    parser.add_argument("--agent-type", default="simul_trans_text", help="Agent type")
-    parser.add_argument("--scorer-type", default="text", help="Scorer type")
-    parser.add_argument(
-        "--start-idx",
-        type=int,
-        default=0,
-        help="Start index of the sentence to evaluate",
-    )
-    parser.add_argument(
-        "--end-idx",
-        type=int,
-        default=float("inf"),
-        help="End index of the sentence to evaluate",
-    )
-    parser.add_argument(
-        "--scores", action="store_true", help="Request scores from server"
-    )
-    parser.add_argument("--reset-server", action="store_true", help="Reset the server")
-    parser.add_argument(
-        "--num-threads", type=int, default=10, help="Number of threads used by agent"
-    )
-    parser.add_argument(
-        "--local", action="store_true", default=False, help="Local evaluation"
-    )
-
-    args, _ = parser.parse_known_args()
-
-    for registry_name, REGISTRY in REGISTRIES.items():
-        choice = getattr(args, registry_name, None)
-        if choice is not None:
-            cls = REGISTRY["registry"][choice]
-            if hasattr(cls, "add_args"):
-                cls.add_args(parser)
-    args = parser.parse_args()
-
-    return args
-
-
-if __name__ == "__main__":
-    args = get_args()
-
-    if args.local:
-        session = SimulSTLocalEvaluationService(args)
-    else:
-        session = SimulSTEvaluationService(args.hostname, args.port)
-
-    if args.reset_server:
-        session.new_session()
-
-    if args.agent_type is not None:
-        agent = build_agent(args)
-        agent.decode(session, args.start_idx, args.end_idx, args.num_threads)
-
-    if args.scores:
-        session.get_scores()
-        print(session.get_scores())
diff --git a/examples/simultaneous_translation/eval/scorers/__init__.py b/examples/simultaneous_translation/eval/scorers/__init__.py
deleted file mode 100644
index 0a0e0a0518..0000000000
--- a/examples/simultaneous_translation/eval/scorers/__init__.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import importlib
-import os
-
-from fairseq import registry
-
-
-(build_scorer, register_scorer, SCORER_REGISTRIES, _) = registry.setup_registry(
-    "--scorer-type"
-)
-
-for file in os.listdir(os.path.dirname(__file__)):
-    if file.endswith(".py") and not file.startswith("_"):
-        module = file[: file.find(".py")]
-        importlib.import_module("scorers." + module)
diff --git a/examples/simultaneous_translation/eval/scorers/scorer.py b/examples/simultaneous_translation/eval/scorers/scorer.py
deleted file mode 100644
index d6d3e30aef..0000000000
--- a/examples/simultaneous_translation/eval/scorers/scorer.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import json
-import os
-from collections import defaultdict
-
-from examples.simultaneous_translation.eval.eval_latency import LatencyScorer
-from vizseq.scorers.bleu import BLEUScorer
-from vizseq.scorers.meteor import METEORScorer
-from vizseq.scorers.ter import TERScorer
-
-
-DEFAULT_EOS = "</s>"
-
-
-class SimulScorer(object):
-    def __init__(self, args):
-        self.tokenizer = args.tokenizer
-        self.output_dir = args.output
-        if args.output is not None:
-            self.output_files = {
-                "text": os.path.join(args.output, "text"),
-                "delay": os.path.join(args.output, "delay"),
-                "scores": os.path.join(args.output, "scores"),
-            }
-        else:
-            self.output_files = None
-        self.eos = DEFAULT_EOS
-        self.data = {"tgt": []}
-        self.reset()
-
-    def get_info(self):
-        return {"num_sentences": len(self)}
-
-    @staticmethod
-    def add_args(parser):
-        # fmt: off
-        parser.add_argument('--src-file', type=str, required=True,
-                            help='Source input file')
-        parser.add_argument('--tgt-file', type=str, required=True,
-                            help='Target reference file')
-        parser.add_argument('--tokenizer', default="13a", choices=["none", "13a"],
-                            help='Tokenizer used for sacrebleu')
-        parser.add_argument('--output', type=str, default=None,
-                            help='Path for output directory')
-        # fmt: on
-
-    def send_src(self, sent_id, *args):
-        raise NotImplementedError
-
-    def recv_hyp(self, sent_id, list_of_tokens):
-        for token in list_of_tokens:
-            self.translations[sent_id].append((token, self.steps[sent_id]))
-
-    def reset(self):
-        self.steps = defaultdict(int)
-        self.translations = defaultdict(list)
-
-    def src_lengths(self):
-        raise NotImplementedError
-
-    def score(self):
-        translations = []
-        delays = []
-        for i in range(1 + max(self.translations.keys())):
-            translations += [" ".join(t[0] for t in self.translations[i][:-1])]
-            delays += [[t[1] for t in self.translations[i]]]
-
-        bleu_score = BLEUScorer(
-            sent_level=False,
-            corpus_level=True,
-            extra_args={"bleu_tokenizer": self.tokenizer},
-        ).score(translations, [self.data["tgt"]])
-
-        ter_score = TERScorer(sent_level=False, corpus_level=True).score(
-            translations, [self.data["tgt"]]
-        )
-        meteor_score = METEORScorer(sent_level=False, corpus_level=True).score(
-            translations, [self.data["tgt"]]
-        )
-
-        latency_score = LatencyScorer().score(
-            [
-                {"src_len": src_len, "delays": delay}
-                for src_len, delay in zip(self.src_lengths(), delays)
-            ],
-            start_from_zero=False,
-        )
-
-        scores = {
-            "BLEU": bleu_score[0],
-            "TER": ter_score[0],
-            "METEOR": meteor_score[0],
-            "DAL": latency_score["differentiable_average_lagging"],
-            "AL": latency_score["average_lagging"],
-            "AP": latency_score["average_proportion"],
-        }
-
-        if self.output_files is not None:
-            try:
-                os.makedirs(self.output_dir, exist_ok=True)
-                self.write_results_to_file(translations, delays, scores)
-            except BaseException as be:
-                print(f"Failed to write results to {self.output_dir}.")
-                print(be)
-                print("Skip writing predictions")
-
-        return scores
-
-    def write_results_to_file(self, translations, delays, scores):
-        if self.output_files["text"] is not None:
-            with open(self.output_files["text"], "w") as f:
-                for line in translations:
-                    f.write(line + "\n")
-
-        if self.output_files["delay"] is not None:
-            with open(self.output_files["delay"], "w") as f:
-                for i, delay in enumerate(delays):
-                    f.write(
-                        json.dumps({"src_len": self.src_lengths()[i], "delays": delay})
-                        + "\n"
-                    )
-
-        with open(self.output_files["scores"], "w") as f:
-            for key, value in scores.items():
-                f.write(f"{key}, {value}\n")
-
-    @classmethod
-    def _load_text_file(cls, file, split=False):
-        with open(file) as f:
-            if split:
-                return [r.strip().split() for r in f]
-            else:
-                return [r.strip() for r in f]
-
-    @classmethod
-    def _load_text_from_json(cls, file):
-        list_to_return = []
-        with open(file) as f:
-            content = json.load(f)
-            for item in content["utts"].values():
-                list_to_return.append(item["output"]["text"].strip())
-        return list_to_return
-
-    @classmethod
-    def _load_wav_info_from_json(cls, file):
-        list_to_return = []
-        with open(file) as f:
-            content = json.load(f)
-            for item in content["utts"].values():
-                list_to_return.append(
-                    {
-                        "path": item["input"]["path"].strip(),
-                        "length": item["input"]["length_ms"],
-                    }
-                )
-        return list_to_return
-
-    @classmethod
-    def _load_wav_info_from_list(cls, file):
-        list_to_return = []
-        with open(file) as f:
-            for line in f:
-                list_to_return.append(
-                    {
-                        "path": line.strip(),
-                    }
-                )
-        return list_to_return
-
-    def __len__(self):
-        return len(self.data["tgt"])
diff --git a/examples/simultaneous_translation/eval/scorers/text_scorer.py b/examples/simultaneous_translation/eval/scorers/text_scorer.py
deleted file mode 100644
index 649a2c7e5c..0000000000
--- a/examples/simultaneous_translation/eval/scorers/text_scorer.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-from . import register_scorer
-from .scorer import SimulScorer
-
-
-@register_scorer("text")
-class SimulTextScorer(SimulScorer):
-    def __init__(self, args):
-        super().__init__(args)
-        self.data = {
-            "src": self._load_text_file(args.src_file, split=True),
-            "tgt": self._load_text_file(args.tgt_file, split=False),
-        }
-
-    def send_src(self, sent_id, *args):
-        if self.steps[sent_id] >= len(self.data["src"][sent_id]):
-            dict_to_return = {
-                "sent_id": sent_id,
-                "segment_id": self.steps[sent_id],
-                "segment": self.eos,
-            }
-            # Consider EOS
-            self.steps[sent_id] = len(self.data["src"][sent_id]) + 1
-        else:
-            dict_to_return = {
-                "sent_id": sent_id,
-                "segment_id": self.steps[sent_id],
-                "segment": self.data["src"][sent_id][self.steps[sent_id]],
-            }
-
-            self.steps[sent_id] += 1
-
-        return dict_to_return
-
-    def src_lengths(self):
-        # +1 for eos
-        return [len(sent) + 1 for sent in self.data["src"]]
diff --git a/examples/simultaneous_translation/eval/server.py b/examples/simultaneous_translation/eval/server.py
deleted file mode 100644
index e44ceaff85..0000000000
--- a/examples/simultaneous_translation/eval/server.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-import argparse
-import json
-import sys
-
-from scorers import build_scorer
-from tornado import ioloop, web
-
-
-DEFAULT_HOSTNAME = "localhost"
-DEFAULT_PORT = 12321
-
-
-class ScorerHandler(web.RequestHandler):
-    def initialize(self, scorer):
-        self.scorer = scorer
-
-
-class EvalSessionHandler(ScorerHandler):
-    def post(self):
-        self.scorer.reset()
-
-    def get(self):
-        r = json.dumps(self.scorer.get_info())
-        self.write(r)
-
-
-class ResultHandler(ScorerHandler):
-    def get(self):
-        r = json.dumps(self.scorer.score())
-        self.write(r)
-
-
-class SourceHandler(ScorerHandler):
-    def get(self):
-        sent_id = int(self.get_argument("sent_id"))
-        segment_size = None
-        if "segment_size" in self.request.arguments:
-            string = self.get_argument("segment_size")
-            if len(string) > 0:
-                segment_size = int(string)
-
-        r = json.dumps(self.scorer.send_src(int(sent_id), segment_size))
-
-        self.write(r)
-
-
-class HypothesisHandler(ScorerHandler):
-    def put(self):
-        sent_id = int(self.get_argument("sent_id"))
-        list_of_tokens = self.request.body.decode("utf-8").strip().split()
-        self.scorer.recv_hyp(sent_id, list_of_tokens)
-
-
-def add_args():
-    parser = argparse.ArgumentParser()
-    # fmt: off
-    parser.add_argument('--hostname', type=str, default=DEFAULT_HOSTNAME,
-                        help='Server hostname')
-    parser.add_argument('--port', type=int, default=DEFAULT_PORT,
-                        help='Server port number')
-
-    args, _ = parser.parse_known_args()
-    # fmt: on
-    return args
-
-
-def start_server(scorer, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT, debug=False):
-    app = web.Application(
-        [
-            (r"/result", ResultHandler, dict(scorer=scorer)),
-            (r"/src", SourceHandler, dict(scorer=scorer)),
-            (r"/hypo", HypothesisHandler, dict(scorer=scorer)),
-            (r"/", EvalSessionHandler, dict(scorer=scorer)),
-        ],
-        debug=debug,
-    )
-    app.listen(port, max_buffer_size=1024 ** 3)
-    sys.stdout.write(f"Evaluation Server Started. Listening to port {port}\n")
-    ioloop.IOLoop.current().start()
-
-
-if __name__ == "__main__":
-    args = add_args()
-    scorer = build_scorer(args)
-    start_server(scorer, args.hostname, args.port, args.debug)
diff --git a/examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py b/examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py
deleted file mode 100644
index 45df5fa227..0000000000
--- a/examples/speech_to_text/simultaneous_translation/agents/simul_trans_agent.py
+++ /dev/null
@@ -1,200 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import json
-import os
-
-from fairseq import checkpoint_utils, utils, tasks
-
-from . import DEFAULT_EOS, GET, SEND
-from .agent import Agent
-
-
-class SimulTransAgent(Agent):
-    def __init__(self, args):
-        # Load Model
-        self.load_model(args)
-
-        # build word spliter
-        self.build_word_splitter(args)
-
-        self.max_len = args.max_len
-
-        self.eos = DEFAULT_EOS
-
-    @staticmethod
-    def add_args(parser):
-        parser.add_argument(
-            "--model-path",
-            type=str,
-            required=True,
-            help="path to your pretrained model.",
-        )
-        parser.add_argument(
-            "--data-bin", type=str, required=True, help="Path of data binary"
-        )
-        parser.add_argument(
-            "--user-dir",
-            type=str,
-            default="example/simultaneous_translation",
-            help="User directory for simultaneous translation",
-        )
-        parser.add_argument(
-            "--src-splitter-type",
-            type=str,
-            default=None,
-            help="Subword splitter type for source text",
-        )
-        parser.add_argument(
-            "--tgt-splitter-type",
-            type=str,
-            default=None,
-            help="Subword splitter type for target text",
-        )
-        parser.add_argument(
-            "--src-splitter-path",
-            type=str,
-            default=None,
-            help="Subword splitter model path for source text",
-        )
-        parser.add_argument(
-            "--tgt-splitter-path",
-            type=str,
-            default=None,
-            help="Subword splitter model path for target text",
-        )
-        parser.add_argument(
-            "--max-len",
-            type=int,
-            default=150,
-            help="Maximum length difference between source and target prediction",
-        )
-        parser.add_argument(
-            "--model-overrides",
-            default="{}",
-            type=str,
-            metavar="DICT",
-            help="A dictionary used to override model args at generation "
-                 "that were used during model training",
-        )
-        # fmt: on
-        return parser
-
-    def load_dictionary(self, task):
-        raise NotImplementedError
-
-    def load_model(self, args):
-        args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..")
-        utils.import_user_module(args)
-        filename = args.model_path
-        if not os.path.exists(filename):
-            raise IOError("Model file not found: {}".format(filename))
-
-        state = checkpoint_utils.load_checkpoint_to_cpu(
-            filename, json.loads(args.model_overrides)
-        )
-
-        saved_args = state["args"]
-        saved_args.data = args.data_bin
-
-        task = tasks.setup_task(saved_args)
-
-        # build model for ensemble
-        self.model = task.build_model(saved_args)
-        self.model.load_state_dict(state["model"], strict=True)
-
-        # Set dictionary
-        self.load_dictionary(task)
-
-    def init_states(self):
-        return {
-            "indices": {"src": [], "tgt": []},
-            "tokens": {"src": [], "tgt": []},
-            "segments": {"src": [], "tgt": []},
-            "steps": {"src": 0, "tgt": 0},
-            "finished": False,
-            "finish_read": False,
-            "model_states": {},
-        }
-
-    def update_states(self, states, new_state):
-        raise NotImplementedError
-
-    def policy(self, states):
-        # Read and Write policy
-        action = None
-
-        while action is None:
-            if states["finished"]:
-                # Finish the hypo by sending eos to server
-                return self.finish_action()
-
-            # Model make decision given current states
-            decision = self.model.decision_from_states(states)
-
-            if decision == 0 and not self.finish_read(states):
-                # READ
-                action = self.read_action(states)
-            else:
-                # WRITE
-                action = self.write_action(states)
-
-            # None means we make decision again but not sending server anything
-            # This happened when read a buffered token
-            # Or predict a subword
-        return action
-
-    def finish_read(self, states):
-        raise NotImplementedError
-
-    def write_action(self, states):
-        token, index = self.model.predict_from_states(states)
-
-        if (
-            index == self.dict["tgt"].eos()
-            or len(states["tokens"]["tgt"]) > self.max_len
-        ):
-            # Finish this sentence is predict EOS
-            states["finished"] = True
- end_idx_last_full_word = self._target_length(states) - - else: - states["tokens"]["tgt"] += [token] - end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word( - states["tokens"]["tgt"] - ) - self._append_indices(states, [index], "tgt") - - if end_idx_last_full_word > states["steps"]["tgt"]: - # Only sent detokenized full words to the server - word = self.word_splitter["tgt"].merge( - states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word] - ) - states["steps"]["tgt"] = end_idx_last_full_word - states["segments"]["tgt"] += [word] - - return {"key": SEND, "value": word} - else: - return None - - def read_action(self, states): - return {"key": GET, "value": None} - - def finish_action(self): - return {"key": SEND, "value": DEFAULT_EOS} - - def reset(self): - pass - - def finish_eval(self, states, new_state): - if len(new_state) == 0 and len(states["indices"]["src"]) == 0: - return True - return False - - def _append_indices(self, states, new_indices, key): - states["indices"][key] += new_indices - - def _target_length(self, states): - return len(states["tokens"]["tgt"]) diff --git a/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py index aa3dba31e2..051785238f 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from examples.simultaneous_translation.utils.latency import LatencyTraining from fairseq.criterions import register_criterion from fairseq.criterions.label_smoothed_cross_entropy import ( LabelSmoothedCrossEntropyCriterion, @@ -31,6 +30,7 @@ def __init__( super().__init__( task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy ) + from examples.simultaneous_translation.utils.latency import LatencyTraining self.eps = label_smoothing self.latency_weight_avg = latency_weight_avg self.latency_weight_avg_type = latency_weight_avg_type diff --git a/fairseq/tasks/simultaneous_translation.py b/fairseq/tasks/simultaneous_translation.py new file mode 100644 index 0000000000..11c7dc1ea9 --- /dev/null +++ b/fairseq/tasks/simultaneous_translation.py @@ -0,0 +1,42 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from fairseq.tasks import register_task +from fairseq.tasks.speech_to_text import SpeechToTextTask +from fairseq.tasks.translation import ( + TranslationTask, TranslationConfig +) + +try: + import examples.simultaneous_translation # noqa + import_successful = True +except BaseException: + import_successful = False + + +logger = logging.getLogger(__name__) + + +def check_import(flag): + if not flag: + raise ImportError( + "'examples.simultaneous_translation' is not correctly imported. " + "Please considering `pip install -e $FAIRSEQ_DIR`." 
+        )
+
+
+@register_task("simul_speech_to_text")
+class SimulSpeechToTextTask(SpeechToTextTask):
+    def __init__(self, args, tgt_dict):
+        check_import(import_successful)
+        super().__init__(args, tgt_dict)
+
+
+@register_task("simul_text_to_text", dataclass=TranslationConfig)
+class SimulTextToTextTask(TranslationTask):
+    def __init__(self, cfg, src_dict, tgt_dict):
+        check_import(import_successful)
+        super().__init__(cfg, src_dict, tgt_dict)
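+
+
+# A minimal (hypothetical) sanity check that the tasks above are registered,
+# assuming fairseq is installed in the current environment:
+#
+#     >>> from fairseq.tasks import TASK_REGISTRY
+#     >>> "simul_text_to_text" in TASK_REGISTRY and "simul_speech_to_text" in TASK_REGISTRY
+#     True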