Implement language model support, decoder only (#1919)
Thanks @funboarder13920 !
1 parent fa7f888 · commit 7cbfbb3
Showing 39 changed files with 2,282 additions and 643 deletions.
@@ -0,0 +1,7 @@
you !
ignored .
elections .
.
<unk> works .
codec to be available soon .
forty years ago .
@@ -0,0 +1,7 @@
you !
that the Irish problem in the Commission should not only a topical problem .
elections , Israel will be developed to start of the crisis , at the crisis , at the crisis , at the crisis , during the crisis , at the crisis .
and <unk>
the July 2003 , has been developed to win <unk> public - thus been developed countries .
might have been <unk> .
<unk> , I think we are going to be able to make it .
@@ -0,0 +1,7 @@
Orlando Bloom and Miranda Kerr still love
However , in an interview , Bloom has said
Jet makers feud over seat width with
A row has flared up between leading plane
The dispute focuses on the width of
They don 't want us to dictate to them what
They know their business better than
@@ -0,0 +1,6 @@
# Corpus opts:
data:
    corpus_1:
        path_src: data/src-train.txt
    valid:
        path_src: data/src-val.txt
@@ -0,0 +1,112 @@
# Language Model Generation

## Step 0: Download and clean the data

Preliminary steps are defined in [`examples/scripts/prepare_wikitext-103_data.sh`](https://github.com/OpenNMT/OpenNMT-py/tree/master/examples/scripts/prepare_wikitext-103_data.sh). The following commands will download the [WikiText103 dataset](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/), remove empty lines and shuffle the training corpus:
```bash
chmod u+x prepare_wikitext-103_data.sh
# the script expects the target data directory as its only argument;
# the rest of this tutorial assumes the corpus ends up under data/wikitext-103-raw/
./prepare_wikitext-103_data.sh data
```

## Step 1: Prepare the subword model - BPE with pyonmttok

This snippet will train a BPE model of 40,000 symbols on the training dataset using pyonmttok. The BPE model will be stored in `subwords.bpe` and the train/valid/test sets will be tokenized and saved.

The tokenized files won't be used for training: the training pipeline iterates over the raw dataset and applies the tokenization on the fly through transforms (see Step 2).

```python
import pyonmttok

args = {
    "mode": "aggressive",
    "joiner_annotate": True,
    "preserve_placeholders": True,
    "case_markup": True,
    "soft_case_regions": True,
    "preserve_segmented_tokens": True,
}
n_symbols = 40000

tokenizer_default = pyonmttok.Tokenizer(**args)
learner = pyonmttok.BPELearner(tokenizer=tokenizer_default, symbols=n_symbols)
# load training corpus
learner.ingest_file("wiki.train.raw")

# learn and store bpe model
tokenizer = learner.learn("subwords.bpe")

# tokenize corpus and save results
for data_file in ["wiki.valid", "wiki.test", "wiki.train"]:
    tokenizer.tokenize_file(f"{data_file}.raw", f"{data_file}.bpe")
```
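
As a quick sanity check (not part of the original snippet), the `tokenizer` returned by `learner.learn` can be applied to a single sentence and the result detokenized again; this assumes the code above has just been run, so `tokenizer` is still in scope:
```python
# Hypothetical sanity check: tokenize one sentence with the learned BPE model,
# then detokenize it; the round trip should restore the original sentence.
sample = "The dispute focuses on the width of the seats ."
tokens, _ = tokenizer.tokenize(sample)
print(tokens)                        # subword tokens with joiner / case annotations
print(tokenizer.detokenize(tokens))  # should print the original sentence
```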

## Step 2: Build the vocabulary
An example yaml configuration for the language modeling task is available in [`examples/wiki_103.yaml`](https://github.com/OpenNMT/OpenNMT-py/tree/master/examples/wiki_103.yaml). This configuration will be used both for building the vocabulary and for training the model.
BPE and language modeling specificities are explained in the following sections.

### Language Model specificities

LM tasks expect a single source, so `path_tgt` is not required.

```yaml
data:
    corpus_1:
        path_src: data/wikitext-103-raw/wiki.train.raw
```
### BPE specificities
To use BPE tokenization on the fly, the following parameters must be in the config file.
Slight differences may be observed between the on-the-fly tokenization and the tokenized files produced in Step 1.
```yaml
src_subword_type: bpe
src_subword_model: data/wikitext-103-raw/subwords.bpe
src_onmttok_kwargs: '{"mode": "aggressive", "joiner_annotate": True, "preserve_placeholders":
    True, "case_markup": True, "soft_case_regions": True, "preserve_segmented_tokens":
    True}'
transforms: [onmt_tokenize]
```
### Build vocabulary command
The vocabulary is built using:
```bash
onmt_build_vocab -config examples/wiki_103.yaml -n_sample -1
```
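
Before training, it can be useful to glance at the resulting vocabulary file. The sketch below only assumes one entry per line and reads the path set as `src_vocab` in the configuration; the exact file layout may vary across OpenNMT-py versions:
```python
# Minimal sketch: count the entries produced by onmt_build_vocab and show the first few.
vocab_path = "data/wikitext-103-raw/run/example.vocab.src"

with open(vocab_path, encoding="utf-8") as f:
    entries = f.read().splitlines()

print(f"{len(entries)} vocabulary entries")
for entry in entries[:5]:
    print(entry)
```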

## Step 3: Train the model
To train a model for LM tasks, the following parameters are required:

* *model_task* is used to specify that the task is language modeling (decoder-only model with the transformer_lm decoder type, source-only dataset expected)
* *decoder_type* must be transformer_lm. This transformer is the one used in GPT-2: [**Language Models are Unsupervised Multitask Learners**](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf). Basically, it is a transformer without an encoder attention block
* *encoder_type* is not used but still needs to be specified
* *share_vocab* must be true. The shifted source plays the role of the target, hence the vocabulary must be shared.
```yaml
model_task: lm
encoder_type: transformer_lm
decoder_type: transformer_lm

share_vocab: true
```
The training is launched using:
```bash
onmt_train -config examples/wiki_103.yaml
```
TensorBoard can be used to monitor the training; the configuration writes the logs to `data/wikitext-103-raw/run/tensorboard`.

**Expected results:** perplexity of 20-22 on the validation set.

## Step 4: Generate output
Options stored in the loaded model will trigger language-model-specific inference.

The input file must contain examples that are already tokenized with the same method as the training data. Here, part of the validation data is used:
```bash
head data/wikitext-103-raw/wiki.valid.bpe | cut -d" " -f-15 > data/wikitext-103-raw/lm_input.txt
```
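
The command above simply reuses already-tokenized validation data as prompts. To feed your own prompts instead, they first need to be tokenized with the BPE model from Step 1. A minimal sketch, assuming a plain-text `my_prompts.txt` (one prompt per line) and a pyonmttok version that accepts `bpe_model_path`:
```python
import pyonmttok

# Hypothetical prompt preparation: apply the Step 1 BPE model to raw prompts
# so that they match the tokenization seen during training.
tokenizer = pyonmttok.Tokenizer(
    "aggressive",
    bpe_model_path="data/wikitext-103-raw/subwords.bpe",
    joiner_annotate=True,
    preserve_placeholders=True,
    case_markup=True,
    soft_case_regions=True,
    preserve_segmented_tokens=True,
)
tokenizer.tokenize_file("my_prompts.txt", "data/wikitext-103-raw/lm_input.txt")
```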

To proceed with inference:
```bash
onmt_translate -model data/wikitext-103-raw/run/model-lm_step_1000000.pt -src data/wikitext-103-raw/lm_input.txt -output data/wikitext-103-raw/lm_pred_input.txt -verbose -n_best 3
```
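
The predictions written to `lm_pred_input.txt` are still BPE-tokenized (and, with `-n_best 3`, contain several hypotheses per input). A rough sketch for turning them back into plain text with the same pyonmttok model, assuming detokenizing whitespace-split tokens is sufficient here:
```python
import pyonmttok

# Hypothetical post-processing: detokenize BPE-level predictions back to plain text.
tokenizer = pyonmttok.Tokenizer(
    "aggressive",
    bpe_model_path="data/wikitext-103-raw/subwords.bpe",
    joiner_annotate=True,
    case_markup=True,
)
with open("data/wikitext-103-raw/lm_pred_input.txt", encoding="utf-8") as fin, \
        open("data/wikitext-103-raw/lm_pred_input.detok.txt", "w", encoding="utf-8") as fout:
    for line in fin:
        fout.write(tokenizer.detokenize(line.split()) + "\n")
```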
@@ -0,0 +1,43 @@
#!/bin/bash

##################################################################################
# This script will download wikitext-103-raw and will do basic data preparation
# for BPE and training
##################################################################################

# provide script usage instructions
if [ $# -eq 0 ]
then
    echo "usage: $0 <data_dir>"
    exit 1
fi

get_seeded_random()
{
    seed="$1"
    openssl enc -aes-256-ctr -pass pass:"$seed" -nosalt \
        </dev/zero 2>/dev/null
}

# set relevant paths
SP_PATH=/usr/local/bin
DATA_PATH=$1
TEST_PATH=$DATA_PATH/test

CUR_DIR=$(pwd)

# Download the default datasets into the $DATA_PATH; mkdir if it doesn't exist
mkdir -p $DATA_PATH
cd $DATA_PATH

echo "Downloading and extracting WikiText-103 (183 MB) for training and inference..."
wget --trust-server-names https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip
unzip wikitext-103-raw-v1.zip
rm wikitext-103-raw-v1.zip
cd wikitext-103-raw

echo "Removing empty lines and shuffling training data"
sed -r '/^\s*$/d' -i wiki.train.raw
sed -r '/^\s*$/d' -i wiki.valid.raw
sed -r '/^\s*$/d' -i wiki.test.raw
sort --random-source=<(get_seeded_random 42) -R -o wiki.train.raw wiki.train.raw
@@ -0,0 +1,71 @@

seed: 42
share_vocab: true
save_data: data/wikitext-103-raw/run/example
## Where the vocab(s) will be written
src_vocab: data/wikitext-103-raw/run/example.vocab.src
src_vocab_size: 60000
tgt_vocab_size: 60000
src_subword_type: bpe
src_subword_model: data/wikitext-103-raw/subwords.bpe
src_onmttok_kwargs: '{"mode": "aggressive", "joiner_annotate": True, "preserve_placeholders":
    True, "case_markup": True, "soft_case_regions": True, "preserve_segmented_tokens":
    True}'
transforms: [onmt_tokenize, filtertoolong]
src_seq_length: 512
tgt_seq_length: 512

# Allow overwriting existing files in the folder
overwrite: True

# Corpus opts:
data:
    corpus_1:
        path_src: data/wikitext-103-raw/wiki.train.raw
    valid:
        path_src: data/wikitext-103-raw/wiki.valid.raw

# Vocabulary files that were just created
src_vocab: data/wikitext-103-raw/run/example.vocab.src

# Train on a single GPU
world_size: 1
gpu_ranks: [0]

# Where to save the checkpoints
save_model: data/wikitext-103-raw/run/model-lm
save_checkpoint_steps: 50000
train_steps: 1000000
valid_steps: 500
report_every: 100
tensorboard: true
tensorboard_log_dir: data/wikitext-103-raw/run/tensorboard

# Model
model_task: lm
encoder_type: transformer_lm
decoder_type: transformer_lm
position_encoding: true
dec_layers: 6
heads: 8
rnn_size: 512
word_vec_size: 512
transformer_ff: 2048
dropout_steps: [0]
dropout: [0.1]
attention_dropout: [0.1]
batch_size: 2048
batch_type: tokens

model_dtype: "fp32"
optim: "adam"
learning_rate: 2
warmup_steps: 8000
decay_method: "noam"
adam_beta2: 0.998
max_grad_norm: 0
label_smoothing: 0.1
param_init: 0
param_init_glorot: true
normalization: "tokens"
@@ -1,12 +1,13 @@
 """Module defining decoders."""
 from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, \
     StdRNNDecoder
-from onmt.decoders.transformer import TransformerDecoder
+from onmt.decoders.transformer import TransformerDecoder, TransformerLMDecoder
 from onmt.decoders.cnn_decoder import CNNDecoder


 str2dec = {"rnn": StdRNNDecoder, "ifrnn": InputFeedRNNDecoder,
-           "cnn": CNNDecoder, "transformer": TransformerDecoder}
+           "cnn": CNNDecoder, "transformer": TransformerDecoder,
+           "transformer_lm": TransformerLMDecoder}

 __all__ = ["DecoderBase", "TransformerDecoder", "StdRNNDecoder", "CNNDecoder",
-           "InputFeedRNNDecoder", "str2dec"]
+           "InputFeedRNNDecoder", "str2dec", "TransformerLMDecoder"]