Restore src_features for v3.0 (#2308)
* Restored src_features for v3
anderleich authored Feb 10, 2023
1 parent 563d207 commit 62c96cc
Showing 24 changed files with 508 additions and 336 deletions.
22 changes: 17 additions & 5 deletions .github/workflows/push.yml
@@ -47,7 +47,6 @@ jobs:
-save_data /tmp/onmt_feat \
-src_vocab /tmp/onmt_feat.vocab.src \
-tgt_vocab /tmp/onmt_feat.vocab.tgt \
-src_feats_vocab '{"feat0": "/tmp/onmt_feat.vocab.feat0"}' \
-n_sample -1 \
&& rm -rf /tmp/sample
- name: Test field/transform dump
@@ -259,21 +258,34 @@ jobs:
-config data/features_data.yaml \
-src_vocab /tmp/onmt_feat.vocab.src \
-tgt_vocab /tmp/onmt_feat.vocab.tgt \
-src_feats_vocab '{"feat0": "/tmp/onmt_feat.vocab.feat0"}' \
-src_vocab_size 1000 -tgt_vocab_size 1000 \
-hidden_size 2 -batch_size 10 \
-num_workers 0 -bucket_size 1024 \
-word_vec_size 5 -hidden_size 10 \
-report_every 5 -train_steps 10 \
-save_model /tmp/onmt.model \
-save_checkpoint_steps 10
- name: Testing training with features and dynamic scoring
run: |
python onmt/bin/train.py \
-config data/features_data.yaml \
-src_vocab /tmp/onmt_feat.vocab.src \
-tgt_vocab /tmp/onmt_feat.vocab.tgt \
-src_vocab_size 1000 -tgt_vocab_size 1000 \
-hidden_size 2 -batch_size 10 \
-word_vec_size 5 -hidden_size 10 \
-num_workers 0 -bucket_size 1024 \
-report_every 5 -train_steps 10 \
-train_metrics "BLEU" "TER" \
-valid_metrics "BLEU" "TER" \
-save_model /tmp/onmt.model \
-save_checkpoint_steps 10
- name: Testing translation with features
run: |
python translate.py \
-model /tmp/onmt.model_step_10.pt \
-src data/data_features/src-test.txt \
-src_feats "{'feat0': 'data/data_features/src-test.feat0'}" \
-verbose
-src data/data_features/src-test-with-feats.txt \
-n_src_feats 1 -verbose
- name: Test RNN translation
run: |
head data/src-test.txt > /tmp/src-test.txt
1 change: 1 addition & 0 deletions data/data_features/src-test-with-feats.txt
@@ -0,0 +1 @@
she│C is│B a│A hard-working.│B
1 change: 0 additions & 1 deletion data/data_features/src-test.feat0

This file was deleted.

3 changes: 3 additions & 0 deletions data/data_features/src-train-with-feats.txt
@@ -0,0 +1,3 @@
however,│A according│A to│A the│A logs,│B she│A is│A a│A hard-working.│C
however,│A according│B to│C the│D logs,│E
she│C is│B a│A hard-working.│B
3 changes: 0 additions & 3 deletions data/data_features/src-train.feat0

This file was deleted.

1 change: 1 addition & 0 deletions data/data_features/src-val-with-feats.txt
@@ -0,0 +1 @@
she│C is│B a│A hard-working.│B
1 change: 0 additions & 1 deletion data/data_features/src-val.feat0

This file was deleted.

16 changes: 12 additions & 4 deletions data/features_data.yaml
@@ -1,11 +1,19 @@

# Corpus opts:
data:
corpus_1:
path_src: data/data_features/src-train-with-feats.txt
path_tgt: data/data_features/tgt-train.txt
transforms: [inferfeats]
corpus_2:
path_src: data/data_features/src-train.txt
path_tgt: data/data_features/tgt-train.txt
src_feats:
feat0: data/data_features/src-train.feat0
transforms: [filterfeats, inferfeats]
transforms: [inferfeats]
valid:
path_src: data/data_features/src-val.txt
path_src: data/data_features/src-val-with-feats.txt
path_tgt: data/data_features/tgt-val.txt
transforms: [inferfeats]

# # Feats options
n_src_feats: 1
src_feats_defaults: "0"
75 changes: 28 additions & 47 deletions docs/source/FAQ.md
@@ -620,39 +620,34 @@ Training options to perform vocabulary update are:

## How can I use source word features?

Extra information can be added to the words in the source sentences by defining word features.
Additional word-level information can be incorporated into the model by defining word features in the source sentence.

Features should be defined in a separate file using blank spaces as a separator and with each row corresponding to a source sentence. An example of the input files:
Word features must be appended to the source text itself, using the special character │ as a feature separator. For instance:

data.src
```
however, according to the logs, she is hard-working.
however│C ■,│N according│L to│L the│L logs│L ■,│N she│L is│L hard-working│L ■.│N
```
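
To make the inline format concrete, here is a minimal, hypothetical sketch (not part of OpenNMT-py) of how such a line splits into parallel token and feature streams, assuming the `│` separator and a single feature per token:

```python
# Hypothetical helper, for illustration only: split a feature-annotated
# line into a token list plus one list per feature stream.
def split_feats(line, n_feats=1, sep="│"):
    tokens, feats = [], [[] for _ in range(n_feats)]
    for item in line.strip().split(" "):
        parts = item.split(sep)
        tokens.append(parts[0])
        for i in range(n_feats):
            feats[i].append(parts[1 + i])
    return tokens, feats

tokens, feats = split_feats("she│C is│B a│A hard-working.│B")
print(tokens)    # ['she', 'is', 'a', 'hard-working.']
print(feats[0])  # ['C', 'B', 'A', 'B']
```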

feat.txt
Prior tokenization is not necessary; when tokenization is applied, features for the resulting subwords are inferred by the `FeatInferTransform` transform. For instance:

```
A C C C C A A B
SRC: however,│C according│L to│L the│L logs,│L she│L is│L hard-working.│L
TOKENIZED SRC: however ■, according to the logs ■, she is hard-working ■.
RESULT: however│C ■,│C according│L to│L the│L logs│L ■,│L she│L is│L hard│L ■-■│L working│L ■.│L
```
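
The following is a rough, self-contained illustration of that idea (an assumption about the behaviour, not the actual `FeatInferTransform` code): each subword produced by tokenization inherits the feature of the original word it came from.

```python
# Rough illustration only: propagate word-level features to subwords
# produced by a joiner-based tokenizer ("■" marks the joiner).
def infer_feats(words, word_feats, subwords, joiner="■"):
    feats, w, consumed = [], 0, ""
    for sw in subwords:
        feats.append(word_feats[w])
        consumed += sw.replace(joiner, "")
        if consumed == words[w]:  # current word fully consumed, move on
            w, consumed = w + 1, ""
    return feats

words = "however, according to the logs, she is hard-working.".split()
word_feats = ["C", "L", "L", "L", "L", "L", "L", "L"]
subwords = "however ■, according to the logs ■, she is hard ■-■ working ■.".split()
print(list(zip(subwords, infer_feats(words, word_feats, subwords))))
# [('however', 'C'), ('■,', 'C'), ('according', 'L'), ..., ('■.', 'L')]
```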

Prior tokenization is not necessary, features will be inferred by using the `FeatInferTransform` transform if tokenization has been applied.
**Options**
- `-n_src_feats`: the expected number of source features per token.
- `-src_feats_defaults` (optional): default values for the features. This is especially useful when mixing task-specific data (with features) with general data that has not been annotated (see the sketch below).
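
As a hypothetical illustration of what the defaults are for (not the actual OpenNMT-py code), a corpus without annotations can effectively be treated as if every token carried the default feature values:

```python
# Illustration only: annotate a plain line with the default feature
# values so it can be mixed with feature-annotated corpora.
def apply_default_feats(line, defaults="0", sep="│"):
    return " ".join(f"{tok}{sep}{defaults}" for tok in line.split())

print(apply_default_feats("she is a hard-working."))
# she│0 is│0 a│0 hard-working.│0
```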

No previous tokenization:
```
SRC: this is a test.
FEATS: A A A B
TOKENIZED SRC: this is a test ■.
RESULT: A A A B <null>
```
For the Transformer architecture, make sure the following options are set appropriately:

Previously tokenized:
```
SRC: this is a test ■.
FEATS: A A A B A
RESULT: A A A B A
```
- `src_word_vec_size` and `tgt_word_vec_size` or `word_vec_size`
- `feat_merge`: how to merge the feature vectors with the word embeddings
- `feat_vec_size` (or, alternatively, `feat_vec_exponent`); a sketch of how these interact follows below
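
As a back-of-the-envelope sketch of why these options interact (an assumption based on the option descriptions, not the actual OpenNMT-py embedding code): with `feat_merge: concat` the feature embeddings are appended to the word embedding, while with `sum` they must share its size.

```python
# Hypothetical sketch of the resulting source embedding width under the
# different merge strategies; the real implementation may differ.
def src_emb_size(word_vec_size, feat_vocab_sizes,
                 feat_merge="concat", feat_vec_size=-1, feat_vec_exponent=0.7):
    if feat_merge == "sum":           # feature vectors summed into the word vector
        return word_vec_size
    if feat_vec_size > 0:
        feat_dims = [feat_vec_size] * len(feat_vocab_sizes)
    else:                             # derive each size from the feature vocab size
        feat_dims = [int(v ** feat_vec_exponent) for v in feat_vocab_sizes]
    return word_vec_size + sum(feat_dims)

print(src_emb_size(512, [4]))                    # concat: 512 + int(4**0.7) = 514
print(src_emb_size(512, [4], feat_merge="sum"))  # sum: 512
```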

**Notes**
- `FilterFeatsTransform` and `FeatInferTransform` are required in order to ensure the functionality.
- The `FeatInferTransform` transform is required for this functionality to work.
- Shared embeddings are not possible (at least with the `feat_merge: concat` method).

Sample config file:
@@ -662,50 +657,36 @@ data:
dummy:
path_src: data/train/data.src
path_tgt: data/train/data.tgt
src_feats:
feat_0: data/train/data.src.feat_0
feat_1: data/train/data.src.feat_1
transforms: [filterfeats, onmt_tokenize, inferfeats, filtertoolong]
transforms: [onmt_tokenize, inferfeats, filtertoolong]
weight: 1
valid:
path_src: data/valid/data.src
path_tgt: data/valid/data.tgt
src_feats:
feat_0: data/valid/data.src.feat_0
feat_1: data/valid/data.src.feat_1
transforms: [filterfeats, onmt_tokenize, inferfeats]
transforms: [onmt_tokenize, inferfeats]
# Transform options
reversible_tokenization: "joiner"
prior_tokenization: true
# Vocab opts
src_vocab: exp/data.vocab.src
tgt_vocab: exp/data.vocab.tgt
src_feats_vocab:
feat_0: exp/data.vocab.feat_0
feat_1: exp/data.vocab.feat_1
# Features options
n_src_feats: 2
src_feats_defaults: "0│1"
feat_merge: "sum"
```

During inference you can pass features by using the `--src_feats` argument. `src_feats` is expected to be a Python like dict, mapping feature names with their data file.
To allow source features in the server, add the following parameters to the server's config file:

```
{'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}
```
**Important note!** During inference, input sentence is expected to be tokenized. Therefore feature inferring should be handled prior to running the translate command. Example:
```bash
python translate.py -model model_step_10.pt -src ../data.txt.tok -output ../data.out --src_feats "{'feat_0': '../data.txt.feats0', 'feat_1': '../data.txt.feats1'}"
"features": {
"n_src_feats": 2,
"src_feats_defaults": "0│1",
"reversible_tokenization": "joiner"
}
```
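
In v3 the features travel inline in the source text for inference as well (as in the `translate.py ... -n_src_feats 1` example above), so a client only needs to build lines in the `token│feature` format. A tiny, hypothetical helper, assuming the `│` separator:

```python
# Illustration only: rebuild an inline feature-annotated line from
# parallel token and feature-stream lists, e.g. as translate/server input.
def join_feats(tokens, feat_streams, sep="│"):
    return " ".join(sep.join([tok] + [stream[i] for stream in feat_streams])
                    for i, tok in enumerate(tokens))

print(join_feats(["she", "is", "a", "hard-working."], [["C", "B", "A", "B"]]))
# she│C is│B a│A hard-working.│B
```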

When using the Transformer architecture make sure the following options are appropriately set:

- `src_word_vec_size` and `tgt_word_vec_size` or `word_vec_size`
- `feat_merge`: how to handle features vecs
- `feat_vec_size` and maybe `feat_vec_exponent`

## How can I set up a translation server?
A REST server was implemented to serve OpenNMT-py models. A discussion is open on the OpenNMT forum: [discussion link](https://forum.opennmt.net/t/simple-opennmt-py-rest-server/1392).

54 changes: 24 additions & 30 deletions onmt/bin/build_vocab.py
@@ -6,10 +6,10 @@
from onmt.utils.parse import ArgumentParser
from onmt.opts import dynamic_prepare_opts
from onmt.inputters.text_corpus import build_corpora_iters, get_corpora
from onmt.inputters.text_utils import process
from onmt.inputters.text_utils import process, append_features_to_text
from onmt.transforms import make_transforms, get_transforms_cls
from onmt.constants import CorpusName, CorpusTask
from collections import Counter, defaultdict
from collections import Counter
import multiprocessing as mp


@@ -40,21 +40,11 @@ def write_files_from_queues(sample_path, queues):
break


# Just for debugging purposes
# It appends features to subwords when dumping to file
def append_features_to_example(example, features):
ex_toks = example.split(' ')
feat_toks = features.split(' ')
toks = [f"{subword}{feat}" for subword, feat in
zip(ex_toks, feat_toks)]
return " ".join(toks)


def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
"""Build vocab on (strided) subpart of the data."""
sub_counter_src = Counter()
sub_counter_tgt = Counter()
sub_counter_src_feats = defaultdict(Counter)
sub_counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
datasets_iterables = build_corpora_iters(
corpora, transforms, opts.data,
skip_empty_level=opts.skip_empty_level,
@@ -70,19 +60,22 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
continue
src_line, tgt_line = (maybe_example['src']['src'],
maybe_example['tgt']['tgt'])
src_line_pretty = src_line
for feat_name, feat_line in maybe_example["src"].items():
if feat_name not in ["src", "src_original"]:
sub_counter_src_feats[feat_name].update(
feat_line.split(' '))
if opts.dump_samples:
src_line_pretty = append_features_to_example(
src_line_pretty, feat_line)
sub_counter_src.update(src_line.split(' '))
sub_counter_tgt.update(tgt_line.split(' '))

if 'feats' in maybe_example['src']:
src_feats_lines = maybe_example['src']['feats']
for i in range(opts.n_src_feats):
sub_counter_src_feats[i].update(
src_feats_lines[i].split(' '))
else:
src_feats_lines = []

if opts.dump_samples:
src_pretty_line = append_features_to_text(
src_line, src_feats_lines)
build_sub_vocab.queues[c_name][offset].put(
(i, src_line_pretty, tgt_line))
(i, src_pretty_line, tgt_line))
if n_sample > 0 and ((i+1) * stride + offset) >= n_sample:
if opts.dump_samples:
build_sub_vocab.queues[c_name][offset].put("break")
@@ -113,7 +106,7 @@ def build_vocab(opts, transforms, n_sample=3):
corpora = get_corpora(opts, task=CorpusTask.TRAIN)
counter_src = Counter()
counter_tgt = Counter()
counter_src_feats = defaultdict(Counter)
counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
from functools import partial
queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
for i in range(opts.num_threads)]
@@ -134,7 +127,8 @@ def build_vocab(opts, transforms, n_sample=3):
func, range(0, opts.num_threads)):
counter_src.update(sub_counter_src)
counter_tgt.update(sub_counter_tgt)
counter_src_feats.update(sub_counter_src_feats)
for i in range(opts.n_src_feats):
counter_src_feats[i].update(sub_counter_src_feats[i])
if opts.dump_samples:
write_process.join()
return counter_src, counter_tgt, counter_src_feats
@@ -166,10 +160,10 @@ def build_vocab_main(opts):
src_counter, tgt_counter, src_feats_counter = build_vocab(
opts, transforms, n_sample=opts.n_sample)

logger.info(f"Counters src:{len(src_counter)}")
logger.info(f"Counters tgt:{len(tgt_counter)}")
for feat_name, feat_counter in src_feats_counter.items():
logger.info(f"Counters {feat_name}:{len(feat_counter)}")
logger.info(f"Counters src: {len(src_counter)}")
logger.info(f"Counters tgt: {len(tgt_counter)}")
for i, feat_counter in enumerate(src_feats_counter):
logger.info(f"Counters src feat_{i}: {len(feat_counter)}")

def save_counter(counter, save_path):
check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
@@ -186,8 +180,8 @@ def save_counter(counter, save_path):
save_counter(src_counter, opts.src_vocab)
save_counter(tgt_counter, opts.tgt_vocab)

for k, v in src_feats_counter.items():
save_counter(v, opts.src_feats_vocab[k])
for i, c in enumerate(src_feats_counter):
save_counter(c, f"{opts.src_vocab}_feat{i}")


def _get_parser():
32 changes: 14 additions & 18 deletions onmt/inputters/inputter.py
@@ -34,11 +34,10 @@ def build_vocab(opt, specials):
""" Build vocabs dict to be stored in the checkpoint
based on vocab files having each line [token, count]
Args:
opt: src_vocab, tgt_vocab, src_feats_vocab
opt: src_vocab, tgt_vocab, n_src_feats
Return:
vocabs: {'src': pyonmttok.Vocab, 'tgt': pyonmttok.Vocab,
'src_feats' : {'feat0': pyonmttok.Vocab,
'feat1': pyonmttok.Vocab, ...},
'src_feats' : [pyonmttok.Vocab, ...]},
'data_task': seq2seq or lm
}
"""
@@ -85,10 +84,10 @@ def _pad_vocab_to_multiple(vocab, multiple):
opt.vocab_size_multiple)
vocabs['tgt'] = tgt_vocab

if opt.src_feats_vocab:
src_feats = {}
for feat_name, filepath in opt.src_feats_vocab.items():
src_f_vocab = _read_vocab_file(filepath, 1)
if opt.n_src_feats > 0:
src_feats_vocabs = []
for i in range(opt.n_src_feats):
src_f_vocab = _read_vocab_file(f"{opt.src_vocab}_feat{i}", 1)
src_f_vocab = pyonmttok.build_vocab_from_tokens(
src_f_vocab,
maximum_size=0,
@@ -101,8 +100,8 @@ def _pad_vocab_to_multiple(vocab, multiple):
if opt.vocab_size_multiple > 1:
src_f_vocab = _pad_vocab_to_multiple(src_f_vocab,
opt.vocab_size_multiple)
src_feats[feat_name] = src_f_vocab
vocabs['src_feats'] = src_feats
src_feats_vocabs.append(src_f_vocab)
vocabs["src_feats"] = src_feats_vocabs

vocabs['data_task'] = opt.data_task

@@ -146,10 +145,8 @@ def vocabs_to_dict(vocabs):
vocabs_dict['src'] = vocabs['src'].ids_to_tokens
vocabs_dict['tgt'] = vocabs['tgt'].ids_to_tokens
if 'src_feats' in vocabs.keys():
vocabs_dict['src_feats'] = {}
for feat in vocabs['src_feats'].keys():
vocabs_dict['src_feats'][feat] = \
vocabs['src_feats'][feat].ids_to_tokens
vocabs_dict['src_feats'] = [feat_vocab.ids_to_tokens
for feat_vocab in vocabs['src_feats']]
vocabs_dict['data_task'] = vocabs['data_task']
return vocabs_dict

@@ -167,9 +164,8 @@ def dict_to_vocabs(vocabs_dict):
else:
vocabs['tgt'] = pyonmttok.build_vocab_from_tokens(vocabs_dict['tgt'])
if 'src_feats' in vocabs_dict.keys():
vocabs['src_feats'] = {}
for feat in vocabs_dict['src_feats'].keys():
vocabs['src_feats'][feat] = \
pyonmttok.build_vocab_from_tokens(
vocabs_dict['src_feats'][feat])
vocabs['src_feats'] = []
for feat_vocab in vocabs_dict['src_feats']:
vocabs['src_feats'].append(
pyonmttok.build_vocab_from_tokens(feat_vocab))
return vocabs