more docs update
vince62s committed Nov 3, 2022
1 parent 24c0c4f commit 3f86f7d
Showing 9 changed files with 50 additions and 32 deletions.
31 changes: 25 additions & 6 deletions README.md
@@ -52,14 +52,34 @@ If you want to optimize the training performance:

### Breaking changes

A few features were dropped between v1 and v2:
Changes between v2 and v3:

Options removed:
`queue_size` and `pool_factor` are no longer needed. Simply adjust `bucket_size` to the number of examples to be loaded by each worker (`num_workers`) of the PyTorch DataLoader.

New options:
num_workers: number of workers for each process. The recommended value is 4 when running on a single GPU and 2 when running on more than one GPU.
add_qkvbias: default is false; however, models trained with v2 will be converted with it set to true. The original Transformer paper used no bias for the Q/K/V nn.Linear layers of the multi-head attention module.
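
As a rough sketch, these options might appear in a v3 training config as follows; the values simply mirror the recommendations above and the example configs touched by this commit, and should be adjusted to your setup:

```yaml
# Illustrative v3 options only -- adjust to your own setup
bucket_size: 262144   # examples loaded by each DataLoader worker before sorting and batching
num_workers: 2        # 2 per process on multi-GPU runs, 4 on a single GPU
add_qkvbias: false    # false for new models; checkpoints converted from v2 keep true
```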

Options renamed:
rnn_size => hidden_size
enc_rnn_size => enc_hid_size
dec_rnn_size => dec_hid_size

Note: tools/convertv2_v3.py will modify these options stored in the checkpoint so that models trained with v2 remain compatible with v3.0.

Inference:
The translator now uses the same dynamic iterator as the trainer.
The new default for inference is `length_penalty=avg`, which provides better BLEU scores in most cases (and is comparable to other toolkits' defaults).
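
For illustration, a minimal inference config relying on this default could look like the sketch below; the checkpoint and file names are hypothetical, and `length_penalty` is written out only to make the new default explicit:

```yaml
# Illustrative translate.py config -- file names are hypothetical
model: model_step_100000.pt
src: test-src.txt
output: test-hyp.txt
beam_size: 5
length_penalty: avg   # the new default, shown explicitly here
```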



Reminder: a few features were dropped between v1 and v2:

- audio, image and video inputs;

For any user who still needs these features, the previous codebase is retained as `legacy` in a separate branch. It will no longer receive extensive development from the core team, but PRs may still be accepted.

- For inference, we default to length_penalty: avg, which usually gives better BLEU and is comparable to other toolkits' defaults.

Feel free to check it out and let us know what you think of the new paradigm!

----
@@ -79,7 +99,7 @@ Table of Contents

OpenNMT-py requires:

- Python >= 3.6
- Python >= 3.7
- PyTorch >= 1.9.0

Install `OpenNMT-py` from `pip`:
@@ -104,8 +124,7 @@ pip install -r requirements.opt.txt

## Features

- :warning: **New in OpenNMT-py 2.0**: [On the fly data processing]([here](https://opennmt.net/OpenNMT-py/FAQ.html#what-are-the-readily-available-on-the-fly-data-transforms).)

- [On the fly data processing](https://opennmt.net/OpenNMT-py/FAQ.html#what-are-the-readily-available-on-the-fly-data-transforms)
- [Encoder-decoder models with multiple RNN cells (LSTM, GRU) and attention types (Luong, Bahdanau)](https://opennmt.net/OpenNMT-py/options/train.html#model-encoder-decoder)
- [Transformer models](https://opennmt.net/OpenNMT-py/FAQ.html#how-do-i-use-the-transformer-model)
- [Copy and Coverage Attention](https://opennmt.net/OpenNMT-py/options/train.html#model-attention)
2 changes: 1 addition & 1 deletion data/README.md
@@ -4,4 +4,4 @@

> python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 1000 -tgt_vocab_size 1000
> python train.py -data data/data -save_model /n/rush_lab/data/tmp_ -world_size 1 -gpu_ranks 0 -rnn_size 100 -word_vec_size 50 -layers 1 -train_steps 100 -optim adam -learning_rate 0.001
> python train.py -data data/data -save_model /n/rush_lab/data/tmp_ -world_size 1 -gpu_ranks 0 -hidden_size 100 -word_vec_size 50 -layers 1 -train_steps 100 -optim adam -learning_rate 0.001
2 changes: 1 addition & 1 deletion docs/source/FAQ.md
@@ -152,7 +152,7 @@ Bear in mind that your models must share the same target vocabulary.

This is naturally embedded in the data configuration format introduced in OpenNMT-py 2.0. Each entry of the `data` configuration will have its own *weight*. When building batches, we'll sequentially take *weight* examples from each corpus.

**Note**: don't worry about batch homogeneity/heterogeneity, the pooling mechanism is here for that reason. Instead of building batches one at a time, we will load `pool_factor` of batches worth of examples, sort them by length, build batches and then yield them in a random order.
**Note**: don't worry about batch homogeneity/heterogeneity, the bucketing mechanism is here for that reason. Instead of building batches one at a time, we will load `bucket_size` examples, sort them by length, build batches and then yield them in a random order.
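
For instance, a weighted `data` section could be sketched as follows (corpus names, paths and weights are purely illustrative):

```yaml
# Illustrative weighted data configuration
data:
    corpus_a:
        path_src: data/corpus_a.src
        path_tgt: data/corpus_a.tgt
        weight: 9      # take 9 examples from corpus_a...
    corpus_b:
        path_src: data/corpus_b.src
        path_tgt: data/corpus_b.tgt
        weight: 1      # ...then 1 from corpus_b, and repeat
bucket_size: 32768     # examples pooled and sorted by length before batches are built
```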

### Example

6 changes: 3 additions & 3 deletions docs/source/conf.py
@@ -71,8 +71,8 @@

# General information about the project.
project = 'OpenNMT-py'
copyright = '2017, srush'
author = 'srush'
copyright = '2017, OpenNMT'
author = 'team'

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
@@ -176,7 +176,7 @@
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'OpenNMT-py.tex', 'OpenNMT-py Documentation',
'srush', 'manual'),
'team', 'manual'),
]


10 changes: 5 additions & 5 deletions docs/source/examples/Library.ipynb
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# How to use OpenNMT-py as a Library"
"# How to use OpenNMT-py as a Library - Valid for v1 and v2 - v3 not yet available"
]
},
{
@@ -475,20 +475,20 @@
"outputs": [],
"source": [
"emb_size = 100\n",
"rnn_size = 500\n",
"hidden_size = 500\n",
"# Specify the core model.\n",
"\n",
"encoder_embeddings = onmt.modules.Embeddings(emb_size, len(src_vocab),\n",
" word_padding_idx=src_padding)\n",
"\n",
"encoder = onmt.encoders.RNNEncoder(hidden_size=rnn_size, num_layers=1,\n",
"encoder = onmt.encoders.RNNEncoder(hidden_size=hidden_size, num_layers=1,\n",
" rnn_type=\"LSTM\", bidirectional=True,\n",
" embeddings=encoder_embeddings)\n",
"\n",
"decoder_embeddings = onmt.modules.Embeddings(emb_size, len(tgt_vocab),\n",
" word_padding_idx=tgt_padding)\n",
"decoder = onmt.decoders.decoder.InputFeedRNNDecoder(\n",
" hidden_size=rnn_size, num_layers=1, bidirectional_encoder=True, \n",
" hidden_size=hidden_size, num_layers=1, bidirectional_encoder=True, \n",
" rnn_type=\"LSTM\", embeddings=decoder_embeddings)\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
@@ -497,7 +497,7 @@
"\n",
"# Specify the tgt word generator and loss computation module\n",
"model.generator = nn.Sequential(\n",
" nn.Linear(rnn_size, len(tgt_vocab)),\n",
" nn.Linear(hidden_size, len(tgt_vocab)),\n",
" nn.LogSoftmax(dim=-1)).to(device)\n",
"\n",
"loss = onmt.utils.loss.NMTLossCompute(\n",
6 changes: 3 additions & 3 deletions examples/onmt.train.fp16.transformer.yaml
@@ -63,10 +63,10 @@ report_every: 100
train_steps: 100000
valid_steps: 4000

bucket_size: 32768
bucket_size: 262144
world_size: 2
gpu_ranks: [0, 1]
num_workers: 4
num_workers: 2
batch_type: "tokens"
batch_size: 4096
valid_batch_size: 8
@@ -92,7 +92,7 @@ decoder_type: transformer
enc_layers: 6
dec_layers: 6
heads: 8
rnn_size: 512
hidden_size: 512
word_vec_size: 512
transformer_ff: 2048
dropout_steps: [0]
5 changes: 2 additions & 3 deletions examples/wmt14_en_de.yaml
@@ -66,14 +66,13 @@ valid_steps: 5000

# Batching
bucket_size: 32768
pool_factor: 8192
world_size: 2
gpu_ranks: [0, 1]
num_workers: 4
batch_type: "tokens"
batch_size: 4096
valid_batch_size: 16
batch_size_multiple: 1
valid_batch_size: 2048
batch_size_multiple: 8
max_generator_batches: 0
accum_count: [3]
accum_steps: [0]
10 changes: 5 additions & 5 deletions onmt/tests/rebuild_test_models.sh
@@ -13,7 +13,7 @@ $my_python train.py \
-src_vocab_size 1000 -tgt_vocab_size 1000 \
-save_model tmp -world_size 1 -gpu_ranks 0 \
-rnn_type LSTM -input_feed 0 \
-rnn_size 256 -word_vec_size 256 \
-hidden_size 256 -word_vec_size 256 \
-layers 1 -train_steps 10000 \
-optim adam -learning_rate 0.001
# -truncated_decoder 5
@@ -35,7 +35,7 @@ $my_python train.py \
-src_vocab_size 1000 -tgt_vocab_size 1000 \
-save_model /tmp/tmp -world_size 1 -gpu_ranks 0 \
-encoder_type cnn -decoder_type cnn \
-rnn_size 256 -word_vec_size 256 \
-hidden_size 256 -word_vec_size 256 \
-layers 2 -train_steps 10000 \
-optim adam -learning_rate 0.001

@@ -54,7 +54,7 @@ if false; then
$my_python train.py \
-config data/morph_data.yaml -src_vocab data/morph_data.vocab.src -tgt_vocab data/morph_data.vocab.tgt \
-save_model tmp -world_size 1 -gpu_ranks 0 \
-rnn_size 400 -word_vec_size 100 \
-hidden_size 400 -word_vec_size 100 \
-layers 1 -train_steps 8000 \
-optim adam -learning_rate 0.001

@@ -76,7 +76,7 @@ $my_python train.py \
-config data/data.yaml -src_vocab data/data.vocab.src -tgt_vocab data/data.vocab.tgt \
-save_model /tmp/tmp \
-batch_type tokens -batch_size 8 -accum_count 4 \
-layers 1 -rnn_size 16 -word_vec_size 16 \
-layers 1 -hidden_size 16 -word_vec_size 16 \
-encoder_type transformer -decoder_type transformer \
-share_embedding -share_vocab \
-train_steps 1000 -world_size 1 -gpu_ranks 0 \
@@ -108,7 +108,7 @@ $my_python build_vocab.py \
-overwrite true

$my_python train.py -config data/lm_data.yaml -save_model /tmp/tmp \
-accum_count 2 -dec_layers 2 -rnn_size 64 -word_vec_size 64 -batch_size 256 \
-accum_count 2 -dec_layers 2 -hidden_size 64 -word_vec_size 64 -batch_size 256 \
-encoder_type transformer_lm -decoder_type transformer_lm -share_embedding \
-train_steps 2000 -max_generator_batches 4 -dropout 0.1 -normalization tokens \
-share_vocab -transformer_ff 256 -max_grad_norm 0 -optim adam -decay_method noam \
10 changes: 5 additions & 5 deletions onmt/tests/test_models.sh
@@ -101,7 +101,7 @@ lstm(){
$PYTHON_BIN train.py -data "$DATA_PATH" \
-save_model "$MODEL_PATH" \
-gpuid $GPUID \
-rnn_size 512 \
-hidden_size 512 \
-word_vec_size 512 \
-layers 1 \
-train_steps 10000 \
@@ -132,7 +132,7 @@ sru(){
$PYTHON_BIN train.py -data "$DATA_PATH" \
-save_model "$MODEL_PATH" \
-gpuid $GPUID \
-rnn_size 512 \
-hidden_size 512 \
-word_vec_size 512 \
-layers 1 \
-train_steps 10000 \
@@ -158,7 +158,7 @@ cnn(){
$PYTHON_BIN train.py -data "$DATA_PATH" \
-save_model "$MODEL_PATH" \
-gpuid $GPUID \
-rnn_size 256 \
-hidden_size 256 \
-word_vec_size 256 \
-layers 2 \
-train_steps 10000 \
@@ -186,7 +186,7 @@ morph(){
$PYTHON_BIN train.py -data "$DATA_DIR"/morph/data \
-save_model "$MODEL_PATH" \
-gpuid $GPUID \
-rnn_size 400 \
-hidden_size 400 \
-word_vec_size 100 \
-layers 1 \
-train_steps 10000 \
@@ -220,7 +220,7 @@ transformer(){
-batch_size 1024 \
-accum_count 4 \
-layers 1 \
-rnn_size 256 \
-hidden_size 256 \
-word_vec_size 256 \
-encoder_type transformer \
-decoder_type transformer \
