Skip to content

Commit

Permalink
Adding ggnn_encoder to support graph-to-sequence processing with atte…
Browse files Browse the repository at this point in the history
…ntion and copy mechanism (#1739)
  • Loading branch information
SteveKommrusch authored Mar 27, 2020
1 parent a4b800d commit 119e8d0
Show file tree
Hide file tree
Showing 10 changed files with 584 additions and 5 deletions.
3 changes: 3 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ jobs:
# test nmt preprocessing w/ sharding and training w/copy
- head -50 data/src-val.txt > /tmp/src-val.txt; head -50 data/tgt-val.txt > /tmp/tgt-val.txt; python preprocess.py -train_src /tmp/src-val.txt -train_tgt /tmp/tgt-val.txt -valid_src /tmp/src-val.txt -valid_tgt /tmp/tgt-val.txt -shard_size 25 -dynamic_dict -save_data /tmp/q -src_vocab_size 1000 -tgt_vocab_size 1000; python train.py -data /tmp/q -rnn_size 2 -batch_size 2 -word_vec_size 5 -report_every 5 -rnn_size 10 -copy_attn -train_steps 10 -pool_factor 10 && rm -rf /tmp/q*.pt

# test Graph neural network preprocessing and training
- cp data/ggnnsrc.txt /tmp/src-val.txt; cp data/ggnntgt.txt /tmp/tgt-val.txt; python preprocess.py -train_src /tmp/src-val.txt -train_tgt /tmp/tgt-val.txt -valid_src /tmp/src-val.txt -valid_tgt /tmp/tgt-val.txt -src_seq_length 1000 -tgt_seq_length 30 -src_vocab data/ggnnsrcvocab.txt -tgt_vocab data/ggnntgtvocab.txt -dynamic_dict -save_data /tmp/q ; python train.py -data /tmp/q -encoder_type ggnn -layers 2 -decoder_type rnn -rnn_size 256 -learning_rate 0.1 -learning_rate_decay 0.8 -global_attention general -batch_size 32 -word_vec_size 256 -bridge -train_steps 10 -src_vocab data/ggnnsrcvocab.txt -n_edge_types 9 -state_dim 256 -n_steps 10 -n_node 64 && rm -rf /tmp/q*.pt

# test im2text preprocessing and training
- head -50 /tmp/im2text/src-val.txt > /tmp/im2text/src-val-head.txt; head -50 /tmp/im2text/tgt-val.txt > /tmp/im2text/tgt-val-head.txt; python preprocess.py -data_type img -src_dir /tmp/im2text/images -train_src /tmp/im2text/src-val-head.txt -train_tgt /tmp/im2text/tgt-val-head.txt -valid_src /tmp/im2text/src-val-head.txt -valid_tgt /tmp/im2text/tgt-val-head.txt -save_data /tmp/im2text/q -tgt_seq_length 100; python train.py -model_type img -data /tmp/im2text/q -rnn_size 2 -batch_size 2 -word_vec_size 5 -report_every 5 -rnn_size 10 -train_steps 10 -pool_factor 10 && rm -rf /tmp/im2text/q*.pt
# test speech2text preprocessing and training
Expand Down
10 changes: 10 additions & 0 deletions data/ggnnsrc.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
im -m im *m *m B B -m O I O B <EOT> 0 1 2 3 4 5 6 6 8 8 5 3 12 <EOT> 1 3 3 5 4 6 7 8 , 1 11 3 10 4 7 7 9 , 0 4 1 5 2 6 , 1 10 2 7 , 4 8 , 4 9 , 0 2 2 4 , 12 0 , 12 1
-m +m *m *m C C im im +m +m *m *m e e C C *m B I B nm +m +m *m *m im im C C O *m O B B *m B B <EOT> 0 1 2 3 4 5 4 5 6 7 8 9 10 11 10 11 8 9 10 10 2 3 4 5 6 7 8 9 10 8 6 7 8 8 5 7 7 12 <EOT> 0 2 1 3 2 4 3 5 8 10 9 11 10 12 11 13 16 18 21 23 22 24 23 25 24 26 30 32 34 35 , 0 20 1 21 2 6 3 7 8 16 9 17 10 14 11 15 16 19 21 34 22 30 23 31 24 29 30 33 34 36 , 0 4 1 5 6 10 7 11 8 12 9 13 20 24 21 25 22 26 23 27 24 28 , 0 6 1 7 6 16 7 17 8 14 9 15 20 30 21 31 22 29 , 0 22 1 23 2 8 3 9 8 18 21 35 22 32 , 1 34 8 19 21 36 22 33 , 6 8 7 9 20 22 25 27 26 28 , 37 0 , 37 1
+s +s -s -s b b *s *s *s +s +s d d b -s b d c c -s *s d c *s c c -s -s d d /s /s *s *s c c d d <EOT> 0 1 2 3 4 5 4 5 6 7 8 9 10 10 8 9 10 10 6 7 9 11 11 9 11 11 2 3 4 5 4 5 6 7 8 9 6 7 12 <EOT> 0 2 1 3 2 4 3 5 6 8 7 9 8 10 9 11 10 12 14 16 19 20 20 21 23 24 26 28 27 29 30 32 31 33 , 0 26 1 27 2 6 3 7 6 18 7 19 8 14 9 15 10 13 14 17 19 23 20 22 23 25 26 30 27 31 30 36 31 37 , 0 4 1 5 6 10 7 11 8 12 19 21 30 34 31 35 , 0 6 1 7 6 14 7 15 8 13 19 22 , 0 28 1 29 2 8 3 9 7 20 8 16 19 24 26 32 27 33 , 0 30 1 31 2 18 3 19 7 23 8 17 19 25 26 36 27 37 , 32 34 33 35 , 38 0 , 38 1
*v -v -v -v nv *v *v nv x *v x b b e x *v x e -s -v e *v nv *v x b *v x <EOT> 0 1 2 3 4 5 6 7 8 9 11 11 8 7 4 5 7 7 2 3 4 5 7 9 11 11 5 7 12 <EOT> 0 2 1 3 2 4 3 5 5 7 6 8 9 10 15 16 19 21 23 24 , 0 18 1 19 2 14 3 15 5 13 6 12 9 11 15 17 19 26 23 25 , 0 4 1 5 2 6 3 7 4 8 5 9 7 10 19 22 21 23 22 24 , 0 14 1 15 3 13 4 12 7 11 22 25 , 0 20 1 21 3 16 19 27 , 1 26 3 17 , 4 6 7 9 18 20 21 22 22 23 26 27 , 28 0 , 28 1
*v -v ns *v ns d d v -v *v v ns ns d *v *v /s /s d d c c w w <EOT> 0 1 2 3 4 5 6 5 2 3 4 5 7 9 4 5 6 7 8 9 8 9 6 7 12 <EOT> 0 2 1 3 3 5 8 10 9 11 14 16 15 17 16 18 17 19 , 0 8 1 9 3 7 8 14 9 15 14 22 15 23 16 20 17 21 , 0 4 1 5 2 6 9 12 11 13 14 18 15 19 , 1 7 14 20 15 21 , 0 10 1 11 8 16 9 17 , 0 14 1 15 8 22 9 23 , 2 4 4 6 11 12 12 13 , 24 0 , 24 1
-s -s *s *s c c /s 1 0 1 +s +s b b is is e e <EOT> 0 1 2 3 4 5 4 5 6 6 2 3 4 5 4 5 6 7 12 <EOT> 0 2 1 3 2 4 3 5 6 8 10 12 11 13 , 0 10 1 11 2 6 3 7 6 9 10 14 11 15 , 0 4 1 5 , 0 6 1 7 , 0 12 1 13 2 8 10 16 11 17 , 0 14 1 15 2 9 , 14 16 15 17 , 18 0 , 18 1
*m *m +m +m im im I I E E nm -m -m *m *m C *s c e D *m E C E *m *s c e D <EOT> 0 1 2 3 4 5 6 7 4 5 2 3 4 5 6 7 8 10 10 8 6 7 8 8 5 7 9 9 7 12 <EOT> 0 2 1 3 2 4 3 5 11 13 12 14 13 15 14 16 16 17 20 22 24 25 25 26 , 0 10 1 11 2 8 3 9 11 24 12 20 13 21 14 19 16 18 20 23 24 28 25 27 , 0 4 1 5 2 6 3 7 10 14 11 15 12 16 14 17 24 26 , 0 8 1 9 10 20 11 21 12 19 14 18 24 27 , 0 12 1 13 11 25 12 22 , 1 24 11 28 12 23 , 4 6 5 7 10 12 , 29 0 , 29 1
-v -v -v -v +v +v *v *v e e w w *v *v E E -v x x o *v *v is is d d *v *v -m -m O O E E z z +v +v +v *v *v z z d v d v +v v v <EOT> 0 1 2 3 4 5 6 7 8 9 8 9 6 7 8 9 8 9 10 10 4 5 6 7 8 9 6 7 8 9 10 11 10 11 8 9 2 3 4 5 6 7 8 8 6 7 4 5 7 7 12 <EOT> 0 2 1 3 2 4 3 5 4 6 5 7 6 8 7 9 12 14 13 15 16 18 20 22 21 23 26 28 27 29 28 30 29 31 36 38 37 39 38 40 39 41 40 42 47 48 , 0 36 1 37 2 20 3 21 4 12 5 13 6 10 7 11 12 16 13 17 16 19 20 26 21 27 26 34 27 35 28 32 29 33 36 46 37 47 38 44 39 45 40 43 47 49 , 0 4 1 5 2 6 3 7 4 8 5 9 20 24 21 25 26 30 27 31 36 40 37 41 38 42 , 0 20 1 21 2 12 3 13 4 10 5 11 26 32 27 33 36 44 37 45 38 43 , 0 38 1 39 2 22 3 23 4 14 5 15 12 18 20 28 21 29 37 48 , 0 46 1 47 2 26 3 27 4 16 5 17 12 19 20 34 21 35 37 49 , 22 24 23 25 , 50 0 , 50 1
*s *s *s d d d ns *s /s d ns b is +s b a ns *s ns b +s b a <EOT> 0 1 2 3 4 4 2 3 4 5 6 8 6 8 10 10 5 7 9 11 9 11 11 12 <EOT> 0 2 1 3 2 4 7 9 8 10 13 14 17 18 20 21 , 0 6 1 7 2 5 7 16 8 12 13 15 17 20 20 22 , 0 4 6 10 8 11 12 14 16 18 17 19 , 0 5 6 12 12 15 16 20 , 0 8 1 9 7 17 8 13 17 21 , 1 16 17 22 , 6 8 10 11 12 13 16 17 18 19 , 23 0 , 23 1
nv nv -v +v nv -v -v v z v z nv w w <EOT> 0 1 2 3 4 5 6 7 8 8 7 4 5 6 12 <EOT> 2 4 3 5 5 7 6 8 , 2 11 3 12 5 10 6 9 , 0 4 1 5 2 6 3 7 4 8 , 0 11 1 12 3 10 4 9 , 2 13 , , 0 2 1 3 4 6 11 13 , 14 0 , 14 1
126 changes: 126 additions & 0 deletions data/ggnnsrcvocab.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
<unk>
<blank>
<s>
</s>
+s
-s
*s
/s
is
ns
+m
-m
*m
tm
im
nm
+v
-v
*v
nv
Null
a
b
c
d
e
0
1
A
B
C
D
E
O
I
v
w
x
y
z
o
left
right
Cancel
Noop
Double
Commute
Distribleft
Distribright
Factorleft
Factorright
Assocleft
Assocright
Flipleft
Flipright
Transpose
<EOT>
,
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
10 changes: 10 additions & 0 deletions data/ggnntgt.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Double Distribright right Noop
Flipright left right left right Noop
left right Assocright left right right Distribleft
Distribright left Distribleft right Distribleft
Distribright left left Double
left right Noop
right Flipleft
left left right right Noop right Assocright
Assocright right right left Flipright
left Flipright left left Flipleft
19 changes: 19 additions & 0 deletions data/ggnntgtvocab.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
<unk>
<blank>
<s>
</s>
left
right
Cancel
Noop
Double
Commute
Distribleft
Distribright
Factorleft
Factorright
Assocleft
Assocright
Flipleft
Flipright
Transpose
94 changes: 94 additions & 0 deletions docs/source/ggnn.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Gated Graph Sequence Neural Networks

Graph-to-sequence networks allow information representable as a graph (such as an annotated NLP sentence or computer code structured as an AST) to be connected to a sequence generator to produce output which can benefit from the graph structure of the input.

The training option `-encoder_type ggnn` implements a GGNN (Gated Graph Neural Network) based on github.com/JamesChuanggg/ggnn.pytorch.git which is based on the paper "Gated Graph Sequence Neural Networks" by Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel.

The ggnn encoder is used for program equivalence proof generation in the paper <a href="https://arxiv.org/abs/2002.06799">Equivalence of Dataflow Graphs via Rewrite Rules Using a Graph-to-Sequence Neural Model</a>. That paper shows the benefit of the graph-to-sequence model over a sequence-to-sequence model for this problem which can be well represented with graphical input. The integration of the ggnn network into the <a href="https://github.com/OpenNMT/OpenNMT-py/">OpenNMT-py</a> system supports attention on the nodes as well as a copy mechanism.

### Dependencies

* There are no additional dependencies beyond the rnn-to-rnn sequence-to-sequence requirements.

### Quick Start

To get started, we provide a toy graph-to-sequence example. We assume that the working directory is `OpenNMT-py` throughout this document.

0) Download the data to a sibling directory.

```
cd ..
git clone https://github.com/SteveKommrusch/OpenNMT-py-ggnn-example
source OpenNMT-py-ggnn-example/env.sh
cd OpenNMT-py
```


1) Preprocess the data.

```
python preprocess.py -train_src $data_path/src-train.txt -train_tgt $data_path/tgt-train.txt -valid_src $data_path/src-val.txt -valid_tgt $data_path/tgt-val.txt -src_seq_length 1000 -tgt_seq_length 30 -src_vocab $data_path/srcvocab.txt -tgt_vocab $data_path/tgtvocab.txt -dynamic_dict -save_data $data_path/final 2>&1 > $data_path/preprocess.out
```

2) Train the model.

```
python train.py -data $data_path/final -encoder_type ggnn -layers 2 -decoder_type rnn -rnn_size 256 -learning_rate 0.1 -start_decay_steps 5000 -learning_rate_decay 0.8 -global_attention general -batch_size 32 -word_vec_size 256 -bridge -train_steps 10000 -gpu_ranks 0 -save_checkpoint_steps 5000 -save_model $data_path/final-model -src_vocab $data_path/srcvocab.txt -n_edge_types 9 -state_dim 256 -n_steps 10 -n_node 64 > $data_path/train.final.out
```

3) Translate the graph of 2 equivalent linear algebra expressions into the axiom list which proves them equivalent.

```
python translate.py -model $data_path/final-model_step_10000.pt -src $data_path/src-test.txt -beam_size 5 -n_best 5 -gpu 0 -output $data_path/pred-test_beam5.txt -dynamic_dict 2>&1 > $data_path/translate5.out
```

### Graph data format

The GGNN implementation leverages the sequence processing and vocabulary
interface of OpenNMT. Each graph is provided on an input line, much like
a sentence is provided on an input line. A graph neural network input line
includes `sentence tokens`, `feature values`, and `edges` separated by
`<EOT>` (end of tokens) tokens. Below is an example of the input for a pair
of algebraic equations structured as a graph:

```
Sentence tokens Feature values Edges
--------------- ------------------ -------------------------------------------------------
- - - 0 a a b b <EOT> 0 1 2 3 4 4 2 3 12 <EOT> 0 2 1 3 2 4 , 0 6 1 7 2 5 , 0 4 , 0 5 , , , , 8 0 , 8 1
```

The equations being represented are `((a - a) - b)` and `(0 - b)`, the
`sentence tokens` of which are provided before the first `<EOT>`. After
the first `<EOT>`, the `feature values` are provided. These are extra
flags with information on each node in the graph. In this case, the 8
sentence tokens have feature flags ranging from 0 to 4; the 9th feature
flag defines a 9th node in the graph which does not have sentence token
information, just feature data. Nodes with any non-number flag (such as
`-` or `.`) will not have a feature added. Multiple groups of features
can be provided by using the `,` delimiter between the first and second
`<EOT>` tokens. After the second `<EOT>` token, edge information is provided.
Edge data is given as node pairs, hence `<EOT> 0 2 1 3` indicates that there
are edges from node 0 to node 2 and from node 1 to node 3. The GGNN supports
multiple edge types (which result mathematically in multiple weight matrices
for the model) and the edge types are separated by `,` tokens after the
second `<EOT>` token.

Note that the source vocabulary file needs to include the `<EOT>` token,
the `,` token, and all of the numbers used for feature flags and node
identifiers in the edge list.


### Options

* `-rnn_type (str)`: style of recurrent unit to use, one of [LSTM]
* `-state_dim (int)`: Number of state dimensions in nodes
* `-n_edge_types (int)`: Number of edge types
* `-bidir_edges (bool)`: True if reverse edges should be automatically created
* `-n_node (int)`: Max nodes in graph
* `-bridge_extra_node (bool)`: True indicates only the vector from the 1st extra node (after token listing) should be used for decoder initialization; False indicates all node vectors should be averaged together for decoder initialization
* `-n_steps (int)`: Steps to advance graph encoder for stabilization
* `-src_vocab (str)`: Path to source vocabulary

### Acknowledgement

This gated graph neural network is leveraged from github.com/JamesChuanggg/ggnn.pytorch.git which is based on the paper "Gated Graph Sequence Neural Networks" by Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel.
16 changes: 16 additions & 0 deletions docs/source/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,22 @@ @article{sennrich2016linguistic
year={2016}
}

@inproceedings{Li2016
author = {Yujia Li and
Daniel Tarlow and
Marc Brockschmidt and
Richard S. Zemel},
title = {Gated Graph Sequence Neural Networks},
booktitle = {4th International Conference on Learning Representations, {ICLR} 2016,
San Juan, Puerto Rico, May 2-4, 2016, Conference Track Proceedings},
year = {2016},
crossref = {DBLP:conf/iclr/2016},
url = {http://arxiv.org/abs/1511.05493},
timestamp = {Thu, 25 Jul 2019 14:25:40 +0200},
biburl = {https://dblp.org/rec/journals/corr/LiTBZ15.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}

@inproceedings{Bahdanau2015,
archivePrefix = {arXiv},
arxivId = {1409.0473},
Expand Down
7 changes: 4 additions & 3 deletions onmt/encoders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
"""Module defining encoders."""
from onmt.encoders.encoder import EncoderBase
from onmt.encoders.transformer import TransformerEncoder
from onmt.encoders.ggnn_encoder import GGNNEncoder
from onmt.encoders.rnn_encoder import RNNEncoder
from onmt.encoders.cnn_encoder import CNNEncoder
from onmt.encoders.mean_encoder import MeanEncoder
from onmt.encoders.audio_encoder import AudioEncoder
from onmt.encoders.image_encoder import ImageEncoder


str2enc = {"rnn": RNNEncoder, "brnn": RNNEncoder, "cnn": CNNEncoder,
"transformer": TransformerEncoder, "img": ImageEncoder,
"audio": AudioEncoder, "mean": MeanEncoder}
str2enc = {"ggnn": GGNNEncoder, "rnn": RNNEncoder, "brnn": RNNEncoder,
"cnn": CNNEncoder, "transformer": TransformerEncoder,
"img": ImageEncoder, "audio": AudioEncoder, "mean": MeanEncoder}

__all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder",
"MeanEncoder", "str2enc"]
Loading

0 comments on commit 119e8d0

Please sign in to comment.