-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding ggnn_encoder to support graph-to-sequence processing with atte…
…ntion and copy mechanism (#1739)
- Loading branch information
1 parent
a4b800d
commit 119e8d0
Showing
10 changed files
with
584 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
im -m im *m *m B B -m O I O B <EOT> 0 1 2 3 4 5 6 6 8 8 5 3 12 <EOT> 1 3 3 5 4 6 7 8 , 1 11 3 10 4 7 7 9 , 0 4 1 5 2 6 , 1 10 2 7 , 4 8 , 4 9 , 0 2 2 4 , 12 0 , 12 1 | ||
-m +m *m *m C C im im +m +m *m *m e e C C *m B I B nm +m +m *m *m im im C C O *m O B B *m B B <EOT> 0 1 2 3 4 5 4 5 6 7 8 9 10 11 10 11 8 9 10 10 2 3 4 5 6 7 8 9 10 8 6 7 8 8 5 7 7 12 <EOT> 0 2 1 3 2 4 3 5 8 10 9 11 10 12 11 13 16 18 21 23 22 24 23 25 24 26 30 32 34 35 , 0 20 1 21 2 6 3 7 8 16 9 17 10 14 11 15 16 19 21 34 22 30 23 31 24 29 30 33 34 36 , 0 4 1 5 6 10 7 11 8 12 9 13 20 24 21 25 22 26 23 27 24 28 , 0 6 1 7 6 16 7 17 8 14 9 15 20 30 21 31 22 29 , 0 22 1 23 2 8 3 9 8 18 21 35 22 32 , 1 34 8 19 21 36 22 33 , 6 8 7 9 20 22 25 27 26 28 , 37 0 , 37 1 | ||
+s +s -s -s b b *s *s *s +s +s d d b -s b d c c -s *s d c *s c c -s -s d d /s /s *s *s c c d d <EOT> 0 1 2 3 4 5 4 5 6 7 8 9 10 10 8 9 10 10 6 7 9 11 11 9 11 11 2 3 4 5 4 5 6 7 8 9 6 7 12 <EOT> 0 2 1 3 2 4 3 5 6 8 7 9 8 10 9 11 10 12 14 16 19 20 20 21 23 24 26 28 27 29 30 32 31 33 , 0 26 1 27 2 6 3 7 6 18 7 19 8 14 9 15 10 13 14 17 19 23 20 22 23 25 26 30 27 31 30 36 31 37 , 0 4 1 5 6 10 7 11 8 12 19 21 30 34 31 35 , 0 6 1 7 6 14 7 15 8 13 19 22 , 0 28 1 29 2 8 3 9 7 20 8 16 19 24 26 32 27 33 , 0 30 1 31 2 18 3 19 7 23 8 17 19 25 26 36 27 37 , 32 34 33 35 , 38 0 , 38 1 | ||
*v -v -v -v nv *v *v nv x *v x b b e x *v x e -s -v e *v nv *v x b *v x <EOT> 0 1 2 3 4 5 6 7 8 9 11 11 8 7 4 5 7 7 2 3 4 5 7 9 11 11 5 7 12 <EOT> 0 2 1 3 2 4 3 5 5 7 6 8 9 10 15 16 19 21 23 24 , 0 18 1 19 2 14 3 15 5 13 6 12 9 11 15 17 19 26 23 25 , 0 4 1 5 2 6 3 7 4 8 5 9 7 10 19 22 21 23 22 24 , 0 14 1 15 3 13 4 12 7 11 22 25 , 0 20 1 21 3 16 19 27 , 1 26 3 17 , 4 6 7 9 18 20 21 22 22 23 26 27 , 28 0 , 28 1 | ||
*v -v ns *v ns d d v -v *v v ns ns d *v *v /s /s d d c c w w <EOT> 0 1 2 3 4 5 6 5 2 3 4 5 7 9 4 5 6 7 8 9 8 9 6 7 12 <EOT> 0 2 1 3 3 5 8 10 9 11 14 16 15 17 16 18 17 19 , 0 8 1 9 3 7 8 14 9 15 14 22 15 23 16 20 17 21 , 0 4 1 5 2 6 9 12 11 13 14 18 15 19 , 1 7 14 20 15 21 , 0 10 1 11 8 16 9 17 , 0 14 1 15 8 22 9 23 , 2 4 4 6 11 12 12 13 , 24 0 , 24 1 | ||
-s -s *s *s c c /s 1 0 1 +s +s b b is is e e <EOT> 0 1 2 3 4 5 4 5 6 6 2 3 4 5 4 5 6 7 12 <EOT> 0 2 1 3 2 4 3 5 6 8 10 12 11 13 , 0 10 1 11 2 6 3 7 6 9 10 14 11 15 , 0 4 1 5 , 0 6 1 7 , 0 12 1 13 2 8 10 16 11 17 , 0 14 1 15 2 9 , 14 16 15 17 , 18 0 , 18 1 | ||
*m *m +m +m im im I I E E nm -m -m *m *m C *s c e D *m E C E *m *s c e D <EOT> 0 1 2 3 4 5 6 7 4 5 2 3 4 5 6 7 8 10 10 8 6 7 8 8 5 7 9 9 7 12 <EOT> 0 2 1 3 2 4 3 5 11 13 12 14 13 15 14 16 16 17 20 22 24 25 25 26 , 0 10 1 11 2 8 3 9 11 24 12 20 13 21 14 19 16 18 20 23 24 28 25 27 , 0 4 1 5 2 6 3 7 10 14 11 15 12 16 14 17 24 26 , 0 8 1 9 10 20 11 21 12 19 14 18 24 27 , 0 12 1 13 11 25 12 22 , 1 24 11 28 12 23 , 4 6 5 7 10 12 , 29 0 , 29 1 | ||
-v -v -v -v +v +v *v *v e e w w *v *v E E -v x x o *v *v is is d d *v *v -m -m O O E E z z +v +v +v *v *v z z d v d v +v v v <EOT> 0 1 2 3 4 5 6 7 8 9 8 9 6 7 8 9 8 9 10 10 4 5 6 7 8 9 6 7 8 9 10 11 10 11 8 9 2 3 4 5 6 7 8 8 6 7 4 5 7 7 12 <EOT> 0 2 1 3 2 4 3 5 4 6 5 7 6 8 7 9 12 14 13 15 16 18 20 22 21 23 26 28 27 29 28 30 29 31 36 38 37 39 38 40 39 41 40 42 47 48 , 0 36 1 37 2 20 3 21 4 12 5 13 6 10 7 11 12 16 13 17 16 19 20 26 21 27 26 34 27 35 28 32 29 33 36 46 37 47 38 44 39 45 40 43 47 49 , 0 4 1 5 2 6 3 7 4 8 5 9 20 24 21 25 26 30 27 31 36 40 37 41 38 42 , 0 20 1 21 2 12 3 13 4 10 5 11 26 32 27 33 36 44 37 45 38 43 , 0 38 1 39 2 22 3 23 4 14 5 15 12 18 20 28 21 29 37 48 , 0 46 1 47 2 26 3 27 4 16 5 17 12 19 20 34 21 35 37 49 , 22 24 23 25 , 50 0 , 50 1 | ||
*s *s *s d d d ns *s /s d ns b is +s b a ns *s ns b +s b a <EOT> 0 1 2 3 4 4 2 3 4 5 6 8 6 8 10 10 5 7 9 11 9 11 11 12 <EOT> 0 2 1 3 2 4 7 9 8 10 13 14 17 18 20 21 , 0 6 1 7 2 5 7 16 8 12 13 15 17 20 20 22 , 0 4 6 10 8 11 12 14 16 18 17 19 , 0 5 6 12 12 15 16 20 , 0 8 1 9 7 17 8 13 17 21 , 1 16 17 22 , 6 8 10 11 12 13 16 17 18 19 , 23 0 , 23 1 | ||
nv nv -v +v nv -v -v v z v z nv w w <EOT> 0 1 2 3 4 5 6 7 8 8 7 4 5 6 12 <EOT> 2 4 3 5 5 7 6 8 , 2 11 3 12 5 10 6 9 , 0 4 1 5 2 6 3 7 4 8 , 0 11 1 12 3 10 4 9 , 2 13 , , 0 2 1 3 4 6 11 13 , 14 0 , 14 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
<unk> | ||
<blank> | ||
<s> | ||
</s> | ||
+s | ||
-s | ||
*s | ||
/s | ||
is | ||
ns | ||
+m | ||
-m | ||
*m | ||
tm | ||
im | ||
nm | ||
+v | ||
-v | ||
*v | ||
nv | ||
Null | ||
a | ||
b | ||
c | ||
d | ||
e | ||
0 | ||
1 | ||
A | ||
B | ||
C | ||
D | ||
E | ||
O | ||
I | ||
v | ||
w | ||
x | ||
y | ||
z | ||
o | ||
left | ||
right | ||
Cancel | ||
Noop | ||
Double | ||
Commute | ||
Distribleft | ||
Distribright | ||
Factorleft | ||
Factorright | ||
Assocleft | ||
Assocright | ||
Flipleft | ||
Flipright | ||
Transpose | ||
<EOT> | ||
, | ||
2 | ||
3 | ||
4 | ||
5 | ||
6 | ||
7 | ||
8 | ||
9 | ||
10 | ||
11 | ||
12 | ||
13 | ||
14 | ||
15 | ||
16 | ||
17 | ||
18 | ||
19 | ||
20 | ||
21 | ||
22 | ||
23 | ||
24 | ||
25 | ||
26 | ||
27 | ||
28 | ||
29 | ||
30 | ||
31 | ||
32 | ||
33 | ||
34 | ||
35 | ||
36 | ||
37 | ||
38 | ||
39 | ||
40 | ||
41 | ||
42 | ||
43 | ||
44 | ||
45 | ||
46 | ||
47 | ||
48 | ||
49 | ||
50 | ||
51 | ||
52 | ||
53 | ||
54 | ||
55 | ||
56 | ||
57 | ||
58 | ||
59 | ||
60 | ||
61 | ||
62 | ||
63 | ||
64 | ||
65 | ||
66 | ||
67 | ||
68 | ||
69 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
Double Distribright right Noop | ||
Flipright left right left right Noop | ||
left right Assocright left right right Distribleft | ||
Distribright left Distribleft right Distribleft | ||
Distribright left left Double | ||
left right Noop | ||
right Flipleft | ||
left left right right Noop right Assocright | ||
Assocright right right left Flipright | ||
left Flipright left left Flipleft |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
<unk> | ||
<blank> | ||
<s> | ||
</s> | ||
left | ||
right | ||
Cancel | ||
Noop | ||
Double | ||
Commute | ||
Distribleft | ||
Distribright | ||
Factorleft | ||
Factorright | ||
Assocleft | ||
Assocright | ||
Flipleft | ||
Flipright | ||
Transpose |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Gated Graph Sequence Neural Networks | ||
|
||
Graph-to-sequence networks allow information represtable as a graph (such as an annotated NLP sentence or computer code structure as an AST) to be connected to a sequence generator to produce output which can benefit from the graph structure of the input. | ||
|
||
The training option `-encoder_type ggnn` implements a GGNN (Gated Graph Neural Network) based on github.com/JamesChuanggg/ggnn.pytorch.git which is based on the paper "Gated Graph Sequence Neural Networks" by Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel. | ||
|
||
The ggnn encoder is used for program equivalence proof generation in the paper <a href="https://arxiv.org/abs/2002.06799">Equivalence of Dataflow Graphs via Rewrite Rules Using a Graph-to-Sequence Neural Model</a>. That paper shows the benefit of the graph-to-sequence model over a sequence-to-sequence model for this problem which can be well represented with graphical input. The integration of the ggnn network into the <a href="https://github.com/OpenNMT/OpenNMT-py/">OpenNMT-py</a> system supports attention on the nodes as well as a copy mechanism. | ||
|
||
### Dependencies | ||
|
||
* There are no additional dependencies beyond the rnn-to-rnn sequeence2sequence requirements. | ||
|
||
### Quick Start | ||
|
||
To get started, we provide a toy graph-to-sequence example. We assume that the working directory is `OpenNMT-py` throughout this document. | ||
|
||
0) Download the data to a sibling directory. | ||
|
||
``` | ||
cd .. | ||
git clone https://github.com/SteveKommrusch/OpenNMT-py-ggnn-example | ||
source OpenNMT-py-ggnn-example/env.sh | ||
cd OpenNMT-py | ||
``` | ||
|
||
|
||
1) Preprocess the data. | ||
|
||
``` | ||
python preprocess.py -train_src $data_path/src-train.txt -train_tgt $data_path/tgt-train.txt -valid_src $data_path/src-val.txt -valid_tgt $data_path/tgt-val.txt -src_seq_length 1000 -tgt_seq_length 30 -src_vocab $data_path/srcvocab.txt -tgt_vocab $data_path/tgtvocab.txt -dynamic_dict -save_data $data_path/final 2>&1 > $data_path/preprocess.out | ||
``` | ||
|
||
2) Train the model. | ||
|
||
``` | ||
python train.py -data $data_path/final -encoder_type ggnn -layers 2 -decoder_type rnn -rnn_size 256 -learning_rate 0.1 -start_decay_steps 5000 -learning_rate_decay 0.8 -global_attention general -batch_size 32 -word_vec_size 256 -bridge -train_steps 10000 -gpu_ranks 0 -save_checkpoint_steps 5000 -save_model $data_path/final-model -src_vocab $data_path/srcvocab.txt -n_edge_types 9 -state_dim 256 -n_steps 10 -n_node 64 > $data_path/train.final.out | ||
``` | ||
|
||
3) Translate the graph of 2 equivalent linear algebra expressions into the axiom list which proves them equivalent. | ||
|
||
``` | ||
python translate.py -model $data_path/final-model_step_10000.pt -src $data_path/src-test.txt -beam_size 5 -n_best 5 -gpu 0 -output $data_path/pred-test_beam5.txt -dynamic_dict 2>&1 > $data_path/translate5.out | ||
``` | ||
|
||
### Graph data format | ||
|
||
The GGNN implementation leverages the sequence processing and vocabulary | ||
interface of OpenNMT. Each graph is provided on an input line, much like | ||
a sentence is provided on an input line. A graph nearal network input line | ||
includes `sentence tokens`, `feature values`, and `edges` separated by | ||
`<EOT>` (end of tokens) tokens. Below is example of the input for a pair | ||
of algebraic equations structured as a graph: | ||
|
||
``` | ||
Sentence tokens Feature values Edges | ||
--------------- ------------------ ------------------------------------------------------- | ||
- - - 0 a a b b <EOT> 0 1 2 3 4 4 2 3 12 <EOT> 0 2 1 3 2 4 , 0 6 1 7 2 5 , 0 4 , 0 5 , , , , 8 0 , 8 1 | ||
``` | ||
|
||
The equations being represented are `((a - a) - b)` and `(0 - b)`, the | ||
`sentence tokens` of which are provided before the first `<EOT>`. After | ||
the first `<EOT>`, the `features values` are provided. These are extra | ||
flags with information on each node in the graph. In this case, the 8 | ||
sentence tokens have feature flags ranging from 0 to 4; the 9th feature | ||
flag defines a 9th node in the graph which does not have sentence token | ||
information, just feature data. Nodes with any non-number flag (such as | ||
`-` or `.`) will not have a feature added. Multiple groups of features | ||
can be provided by using the `,` delimiter between the first and second | ||
'<EOT>' tokens. After the second `<EOT>` token, edge information is provided. | ||
Edge data is given as node pairs, hence `<EOT> 0 2 1 3` indicates that there | ||
are edges from node 0 to node 2 and from node 1 to node 3. The GGNN supports | ||
multiple edge types (which result mathematically in multiple weight matrices | ||
for the model) and the edge types are separated by `,` tokens after the | ||
second `<EOT>` token. | ||
|
||
Note that the source vocabulary file needs to include the '<EOT>' token, | ||
the ',' token, and all of the numbers used for feature flags and node | ||
identifiers in the edge list. | ||
|
||
|
||
### Options | ||
|
||
* `-rnn_type (str)`: style of recurrent unit to use, one of [LSTM] | ||
* `-state_dim (int)`: Number of state dimensions in nodes | ||
* `-n_edge_types (int)`: Number of edge types | ||
* `-bidir_edges (bool)`: True if reverse edges should be automatically created | ||
* `-n_node (int)`: Max nodes in graph | ||
* `-bridge_extra_node (bool)`: True indicates only the vector from the 1st extra node (after token listing) should be used for decoder initialization; False indicates all node vectors should be averaged together for decoder initialization | ||
* `-n_steps (int)`: Steps to advance graph encoder for stabilization | ||
* `-src_vocab (int)`: Path to source vocabulary | ||
|
||
### Acknowledgement | ||
|
||
This gated graph neural network is leveraged from github.com/JamesChuanggg/ggnn.pytorch.git which is based on the paper "Gated Graph Sequence Neural Networks" by Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,17 @@ | ||
"""Module defining encoders.""" | ||
from onmt.encoders.encoder import EncoderBase | ||
from onmt.encoders.transformer import TransformerEncoder | ||
from onmt.encoders.ggnn_encoder import GGNNEncoder | ||
from onmt.encoders.rnn_encoder import RNNEncoder | ||
from onmt.encoders.cnn_encoder import CNNEncoder | ||
from onmt.encoders.mean_encoder import MeanEncoder | ||
from onmt.encoders.audio_encoder import AudioEncoder | ||
from onmt.encoders.image_encoder import ImageEncoder | ||
|
||
|
||
str2enc = {"rnn": RNNEncoder, "brnn": RNNEncoder, "cnn": CNNEncoder, | ||
"transformer": TransformerEncoder, "img": ImageEncoder, | ||
"audio": AudioEncoder, "mean": MeanEncoder} | ||
str2enc = {"ggnn": GGNNEncoder, "rnn": RNNEncoder, "brnn": RNNEncoder, | ||
"cnn": CNNEncoder, "transformer": TransformerEncoder, | ||
"img": ImageEncoder, "audio": AudioEncoder, "mean": MeanEncoder} | ||
|
||
__all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder", | ||
"MeanEncoder", "str2enc"] |
Oops, something went wrong.