Commit 0d0400c

fix the library
1 parent 3cd1114 commit 0d0400c

File tree

5 files changed, +76 -124 lines changed


graph4nlp/pytorch/data/dataset.py

Lines changed: 52 additions & 39 deletions
@@ -36,7 +36,7 @@
 from ..modules.utils.tree_utils import Tree
 from ..modules.utils.tree_utils import Vocab as VocabForTree
 from ..modules.utils.tree_utils import VocabForAll
-from ..modules.utils.vocab_utils import VocabModel
+from ..modules.utils.vocab_utils import Vocab, VocabModel


 class DataItem(object):
@@ -146,6 +146,16 @@ def extract(self):
         output_tokens = self.tokenizer(self.output_text)

         return input_tokens, output_tokens
+
+    def extract_edge_tokens(self):
+        g: GraphData = self.graph
+        edge_tokens = []
+        for i in range(g.get_edge_num()):
+            if "token" in g.edge_attributes[i]:
+                edge_tokens.append(g.edge_attributes[i]["token"])
+            else:
+                edge_tokens.append("")
+        return edge_tokens


 class Text2LabelDataItem(DataItem):
@@ -312,6 +322,7 @@ def __init__(
         reused_vocab_model=None,
         nlp_processor_args=None,
         init_edge_vocab=False,
+        is_hetero=False,
         **kwargs,
     ):
         """
@@ -358,6 +369,10 @@ def __init__(
             vocabulary data is located.
         nlp_processor_args: dict, default=None
             It contains the parameter for nlp processor such as ``stanza``.
+        init_edge_vocab: bool, default=False
+            Whether to initialize the edge vocabulary.
+        is_hetero: bool, default=False
+            Whether the graph is heterogeneous.
         kwargs
         """
         super(Dataset, self).__init__()
@@ -387,6 +402,7 @@ def __init__(
         self.topology_subdir = topology_subdir
         self.use_val_for_vocab = use_val_for_vocab
         self.init_edge_vocab = init_edge_vocab
+        self.is_hetero = is_hetero
         for k, v in kwargs.items():
             setattr(self, k, v)
         self.__indices__ = None
@@ -428,8 +444,6 @@ def __init__(

             vocab = torch.load(self.processed_file_paths["vocab"])
             self.vocab_model = vocab
-            if init_edge_vocab:
-                self.edge_vocab = torch.load(self.processed_file_paths["edge_vocab"])

             if hasattr(self, "reused_label_model"):
                 self.label_model = LabelModel.build(self.processed_file_paths["label"])
@@ -663,6 +677,7 @@ def build_vocab(self):
             target_pretrained_word_emb_name=self.target_pretrained_word_emb_name,
             target_pretrained_word_emb_url=self.target_pretrained_word_emb_url,
             word_emb_size=self.word_emb_size,
+            init_edge_vocab=self.init_edge_vocab,
         )
         self.vocab_model = vocab_model

@@ -709,41 +724,6 @@ def _process(self):
             self.test = self.build_topology(self.test)
         if "val" in self.__dict__:
             self.val = self.build_topology(self.val)
-        # build_edge_vocab and save
-        if self.init_edge_vocab:
-            self.edge_vocab = {}
-            s = set()
-            try:
-                for i in self.train:
-                    graph = i.graph
-                    for edge_idx in range(graph.get_edge_num()):
-                        if "token" in graph.edge_attributes[edge_idx]:
-                            edge_token = graph.edge_attributes[edge_idx]["token"]
-                            s.add(edge_token)
-            except Exception as e:
-                pass
-            try:
-                for i in self.test:
-                    graph = i.graph
-                    for edge_idx in range(graph.get_edge_num()):
-                        if "token" in graph.edge_attributes[edge_idx]:
-                            edge_token = graph.edge_attributes[edge_idx]["token"]
-                            s.add(edge_token)
-            except Exception as e:
-                pass
-            try:
-                for i in self.val:
-                    graph = i.graph
-                    for edge_idx in range(graph.get_edge_num()):
-                        if "token" in graph.edge_attributes[edge_idx]:
-                            edge_token = graph.edge_attributes[edge_idx]["token"]
-                            s.add(edge_token)
-            except Exception as e:
-                pass
-            s.add("")
-            self.edge_vocab = {v: k for k, v in enumerate(s)}
-            print('edge vocab size:', len(self.edge_vocab))
-            torch.save(self.edge_vocab, self.processed_file_paths["edge_vocab"])

         self.build_vocab()

@@ -1116,6 +1096,11 @@ def build_vocab(self):
                 pretrained_word_emb_cache_dir=self.pretrained_word_emb_cache_dir,
                 embedding_dims=self.dec_emb_size,
             )
+        if self.init_edge_vocab:
+            all_edge_words = VocabModel.collect_edge_vocabs(data_for_vocab, self.tokenizer, lower_case=self.lower_case)
+            edge_vocab = Vocab(lower_case=self.lower_case, tokenizer=self.tokenizer)
+            edge_vocab.build_vocab(all_edge_words, max_vocab_size=None, min_vocab_freq=1)
+            edge_vocab.randomize_embeddings(self.word_emb_size)

         if self.share_vocab:
             all_words = Counter()
@@ -1158,6 +1143,7 @@ def build_vocab(self):
             in_word_vocab=src_vocab_model,
             out_word_vocab=tgt_vocab_model,
             share_vocab=src_vocab_model if self.share_vocab else None,
+            edge_vocab=edge_vocab if self.init_edge_vocab else None,
         )

         return self.vocab_model
@@ -1175,6 +1161,18 @@ def vectorization(self, data_items):
             token_matrix = torch.tensor(token_matrix, dtype=torch.long)
             graph.node_features["token_id"] = token_matrix

+            if self.is_hetero:
+                for edge_idx in range(graph.get_edge_num()):
+                    if "token" in graph.edge_attributes[edge_idx]:
+                        edge_token = graph.edge_attributes[edge_idx]["token"]
+                    else:
+                        edge_token = ""
+                    edge_token_id = self.edge_vocab[edge_token]
+                    graph.edge_attributes[edge_idx]["token_id"] = edge_token_id
+                    token_matrix.append([edge_token_id])
+                token_matrix = torch.tensor(token_matrix, dtype=torch.long)
+                graph.edge_features["token_id"] = token_matrix
+
             tgt = item.output_text
             tgt_list = self.vocab_model.out_word_vocab.get_symbol_idx_for_list(tgt.split())
             output_tree = Tree.convert_to_tree(
@@ -1183,7 +1181,7 @@ def vectorization(self, data_items):
             item.output_tree = output_tree

     @classmethod
-    def _vectorize_one_dataitem(cls, data_item, vocab_model, use_ie=False):
+    def _vectorize_one_dataitem(cls, data_item, vocab_model, use_ie=False, is_hetero=False):
         item = deepcopy(data_item)
         graph: GraphData = item.graph
         token_matrix = []
@@ -1195,6 +1193,21 @@ def _vectorize_one_dataitem(cls, data_item, vocab_model, use_ie=False):
         token_matrix = torch.tensor(token_matrix, dtype=torch.long)
         graph.node_features["token_id"] = token_matrix

+        if is_hetero:
+            if not hasattr(vocab_model, "edge_vocab"):
+                raise ValueError("Vocab model must have edge vocab attribute")
+            token_matrix = []
+            for edge_idx in range(graph.get_edge_num()):
+                if "token" in graph.edge_attributes[edge_idx]:
+                    edge_token = graph.edge_attributes[edge_idx]["token"]
+                else:
+                    edge_token = ""
+                edge_token_id = vocab_model.edge_vocab[edge_token]
+                graph.edge_attributes[edge_idx]["token_id"] = edge_token_id
+                token_matrix.append([edge_token_id])
+            token_matrix = torch.tensor(token_matrix, dtype=torch.long)
+            graph.edge_features["token_id"] = token_matrix
+
         if isinstance(item.output_text, str):
             tgt = item.output_text
             tgt_list = vocab_model.out_word_vocab.get_symbol_idx_for_list(tgt.split())
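
For reference, the edge vectorization added to vectorization and _vectorize_one_dataitem boils down to mapping each edge's optional "token" attribute to an id and stacking the ids into a (num_edges, 1) tensor. Below is a minimal, self-contained sketch of that step; the plain dicts and the toy edge_vocab stand in for the library's GraphData.edge_attributes and edge vocabulary and are illustrative only.

# Minimal sketch of the edge-token vectorization added above, with plain
# dicts standing in for GraphData.edge_attributes and the edge vocabulary.
import torch

edge_attributes = [{"token": "ARG0"}, {"token": "ARG1"}, {}]  # third edge has no "token"
edge_vocab = {"": 0, "ARG0": 1, "ARG1": 2}                    # toy mapping; "" is the fallback entry

token_matrix = []
for edge_idx in range(len(edge_attributes)):
    edge_token = edge_attributes[edge_idx].get("token", "")   # same fallback as the diff's if/else
    edge_token_id = edge_vocab[edge_token]
    edge_attributes[edge_idx]["token_id"] = edge_token_id
    token_matrix.append([edge_token_id])

edge_token_ids = torch.tensor(token_matrix, dtype=torch.long)  # shape (num_edges, 1)
print(edge_token_ids.tolist())                                 # [[1], [2], [0]]

Edges without a "token" attribute fall back to the empty string, which is why the empty string always has an entry in the edge vocabulary.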

graph4nlp/pytorch/datasets/mawps.py

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,7 @@ def __init__(
        for_inference=False,
        reused_vocab_model=None,
        init_edge_vocab=False,
+        is_hetero=False,
    ):
        """
        Parameters
@@ -120,4 +121,5 @@ def __init__(
            for_inference=for_inference,
            reused_vocab_model=reused_vocab_model,
            init_edge_vocab=init_edge_vocab,
+            is_hetero=is_hetero,
        )
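
The mawps change is pure parameter plumbing: MawpsDatasetForTree accepts the two new flags and forwards them unchanged to the base dataset constructor. A hedged sketch of that pass-through pattern with hypothetical ToyDataset / ToyMawpsDataset classes (names and signatures are illustrative, not the library's actual API):

# Hypothetical classes illustrating the keyword pass-through used above.
class ToyDataset:
    def __init__(self, root_dir, init_edge_vocab=False, is_hetero=False):
        self.root_dir = root_dir
        self.init_edge_vocab = init_edge_vocab
        self.is_hetero = is_hetero


class ToyMawpsDataset(ToyDataset):
    def __init__(self, root_dir, init_edge_vocab=False, is_hetero=False):
        super().__init__(
            root_dir,
            init_edge_vocab=init_edge_vocab,
            is_hetero=is_hetero,  # new flag forwarded unchanged to the base class
        )


ds = ToyMawpsDataset("data/mawps", init_edge_vocab=True, is_hetero=True)
print(ds.is_hetero)  # True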

graph4nlp/pytorch/modules/graph_embedding_initialization/embedding_construction.py

Lines changed: 3 additions & 84 deletions
@@ -140,9 +140,6 @@ def __init__(
            "w2v_bert",
            "w2v_bert_bilstm",
            "w2v_bert_bigru",
-            "w2v_amr",
-            "w2v_bilstm_amr",
-            "w2v_bilstm_amr_pos",
        ), "emb_strategy must be one of ('w2v', 'w2v_bilstm', 'w2v_bigru', 'bert', 'bert_bilstm', "
        "'bert_bigru', 'w2v_bert', 'w2v_bert_bilstm', 'w2v_bert_bigru')"

@@ -163,14 +160,6 @@ def __init__(
            seq_info_encode_strategy = "none"
        else:
            seq_info_encode_strategy = "none"
-        if "amr" in emb_strategy:
-            seq_info_encode_strategy = "bilstm"
-
-        if "pos" in emb_strategy:
-            word_emb_type.add("pos")
-            #word_emb_type.add("entity_label")
-            word_emb_type.add("position")
-
        if "w2v" in emb_strategy:
            word_emb_type.add("w2v")

@@ -222,17 +211,6 @@ def __init__(
        else:
            rnn_input_size = word_emb_size

-        if "pos" in word_emb_type:
-            self.word_emb_layers["pos"] = WordEmbedding(37, 50)
-            rnn_input_size += 50
-
-        if "entity_label" in word_emb_type:
-            self.word_emb_layers["entity_label"] = WordEmbedding(26, 50)
-            rnn_input_size += 50
-
-        if "position" in word_emb_type:
-            pass
-
        if "seq_bert" in word_emb_type:
            rnn_input_size += self.word_emb_layers["seq_bert"].bert_model.config.hidden_size

@@ -250,8 +228,6 @@ def __init__(
        else:
            self.output_size = rnn_input_size
            self.seq_info_encode_layer = None
-
-        #self.fc = nn.Linear(376, 300)

    def forward(self, batch_gd):
        """Compute initial node/edge embeddings.
@@ -284,60 +260,6 @@ def forward(self, batch_gd):
                word_feat, self.word_dropout, shared_axes=[-2], training=self.training
            )
            feat.append(word_feat)
-            if any(batch_gd.batch_graph_attributes):
-                tot = 0
-                gd_list = from_batch(batch_gd)
-                for i, g in enumerate(gd_list):
-                    sentence_id = g.graph_attributes["sentence_id"].to(batch_gd.device)
-                    seq_feat = []
-                    if "w2v" in self.word_emb_layers:
-                        word_feat = self.word_emb_layers["w2v"](sentence_id)
-                        word_feat = dropout_fn(
-                            word_feat, self.word_dropout, shared_axes=[-2], training=self.training
-                        )
-                        seq_feat.append(word_feat)
-                    else:
-                        RuntimeError("No word embedding layer")
-                    if "pos" in self.word_emb_layers:
-                        sentence_pos = g.graph_attributes["pos_tag_id"].to(batch_gd.device)
-                        pos_feat = self.word_emb_layers["pos"](sentence_pos)
-                        pos_feat = dropout_fn(
-                            pos_feat, self.word_dropout, shared_axes=[-2], training=self.training
-                        )
-                        seq_feat.append(pos_feat)
-
-                    if "entity_label" in self.word_emb_layers:
-                        sentence_entity_label = g.graph_attributes["entity_label_id"].to(batch_gd.device)
-                        entity_label_feat = self.word_emb_layers["entity_label"](sentence_entity_label)
-                        entity_label_feat = dropout_fn(
-                            entity_label_feat, self.word_dropout, shared_axes=[-2], training=self.training
-                        )
-                        seq_feat.append(entity_label_feat)
-
-                    seq_feat = torch.cat(seq_feat, dim=-1)
-
-                    raw_tokens = [dd.strip().split() for dd in g.graph_attributes["sentence"]]
-                    l = [len(s) for s in raw_tokens]
-                    rnn_state = self.seq_info_encode_layer(
-                        seq_feat, torch.LongTensor(l).to(batch_gd.device)
-                    )
-                    if isinstance(rnn_state, (tuple, list)):
-                        rnn_state = rnn_state[0]
-
-                    # update node features
-                    for j in range(g.get_node_num()):
-                        id = g.node_attributes[j]["sentence_id"]
-                        if g.node_attributes[j]["id"] in batch_gd.batch_graph_attributes[i]["mapping"][id]:
-                            rel_list = batch_gd.batch_graph_attributes[i]["mapping"][id][g.node_attributes[j]["id"]]
-                            state = []
-                            for rel in rel_list:
-                                if rel[1] == "node":
-                                    state.append(rnn_state[id][rel[0]])
-                            # replace embedding of the node
-                            if len(state) > 0:
-                                feat[0][tot + j][0] = torch.stack(state, 0).mean(0)
-
-                    tot += g.get_node_num()

        if "node_edge_bert" in self.word_emb_layers:
            input_data = [
@@ -352,17 +274,14 @@ def forward(self, batch_gd):

        if len(feat) > 0:
            feat = torch.cat(feat, dim=-1)
-            if not any(batch_gd.batch_graph_attributes):
-                node_token_lens = torch.clamp((token_ids != Vocab.PAD).sum(-1), min=1)
-                feat = self.node_edge_emb_layer(feat, node_token_lens)
-            else:
-                feat = feat.squeeze(dim=1)
+            node_token_lens = torch.clamp((token_ids != Vocab.PAD).sum(-1), min=1)
+            feat = self.node_edge_emb_layer(feat, node_token_lens)
            if isinstance(feat, (tuple, list)):
                feat = feat[-1]

            feat = batch_gd.split_features(feat)

-        if (self.seq_info_encode_layer is None and "seq_bert" not in self.word_emb_layers) or any(batch_gd.batch_graph_attributes):
+        if self.seq_info_encode_layer is None and "seq_bert" not in self.word_emb_layers:
            if isinstance(feat, list):
                feat = torch.cat(feat, -1)

graph4nlp/pytorch/modules/utils/tree_utils.py

Lines changed: 2 additions & 1 deletion
@@ -132,10 +132,11 @@ def convert_to_tree(r_list, i_left, i_right, tgt_vocab):


class VocabForAll:
-    def __init__(self, in_word_vocab, out_word_vocab, share_vocab):
+    def __init__(self, in_word_vocab, out_word_vocab, share_vocab, edge_vocab=None):
        self.in_word_vocab = in_word_vocab
        self.out_word_vocab = out_word_vocab
        self.share_vocab = share_vocab
+        self.edge_vocab = edge_vocab

    def get_vocab_size(self):
        if hasattr(self, "share_vocab"):
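
VocabForAll now carries an optional edge_vocab, and _vectorize_one_dataitem checks for it with hasattr before use; that guard mainly matters for vocab objects serialized before the attribute existed, since new instances always have it (possibly None). A small sketch of the pattern with a hypothetical ToyVocabForAll container:

# Hypothetical stand-in for VocabForAll showing the optional edge_vocab slot
# and the hasattr guard used by _vectorize_one_dataitem.
class ToyVocabForAll:
    def __init__(self, in_word_vocab, out_word_vocab, share_vocab=None, edge_vocab=None):
        self.in_word_vocab = in_word_vocab
        self.out_word_vocab = out_word_vocab
        self.share_vocab = share_vocab
        self.edge_vocab = edge_vocab


vocab_model = ToyVocabForAll(in_word_vocab={"hi": 0}, out_word_vocab={"hi": 0})

if not hasattr(vocab_model, "edge_vocab") or vocab_model.edge_vocab is None:
    print("no edge vocab available; heterogeneous vectorization would raise here")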

graph4nlp/pytorch/modules/utils/vocab_utils.py

Lines changed: 17 additions & 0 deletions
@@ -47,6 +47,8 @@ class VocabModel(object):
        Word embedding size, default: ``None``.
    share_vocab : boolean
        Specify whether to share vocab between input and output text, default: ``True``.
+    init_edge_vocab: boolean
+        Specify whether to initialize edge vocab, default: ``False``.

    Examples
    -------
@@ -82,6 +84,7 @@ def __init__(
        # pretrained_word_emb_file=None,
        word_emb_size=None,
        share_vocab=True,
+        init_edge_vocab=False,
    ):
        super(VocabModel, self).__init__()
        self.tokenizer = tokenizer
@@ -150,6 +153,12 @@ def __init__(
                self.out_word_vocab.randomize_embeddings(word_emb_size)
            else:
                self.out_word_vocab = self.in_word_vocab
+
+        if init_edge_vocab:
+            all_edge_words = VocabModel.collect_edge_vocabs(data_set, self.tokenizer, lower_case=lower_case)
+            self.edge_vocab = Vocab(lower_case=lower_case, tokenizer=self.tokenizer)
+            self.edge_vocab.build_vocab(all_edge_words, max_vocab_size=None, min_vocab_freq=1)
+            self.edge_vocab.randomize_embeddings(word_emb_size)

        if share_vocab:
            print("[ Initialized word embeddings: {} ]".format(self.in_word_vocab.embeddings.shape))
@@ -265,6 +274,14 @@ def collect_vocabs(all_instances, tokenizer, lower_case=True, share_vocab=True):
            all_words[1].update(extracted_tokens[1])

        return all_words
+    @staticmethod
+    def collect_edge_vocabs(all_instances, tokenizer, lower_case=True):
+        """Count vocabulary tokens for edge."""
+        all_edges = Counter()
+        for instance in all_instances:
+            extracted_edge_tokens = instance.extract_edge_tokens()
+            all_edges.update(extracted_edge_tokens)
+        return all_edges


class WordEmbModel(Vectors):
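
collect_edge_vocabs is a plain Counter over the per-item edge tokens returned by extract_edge_tokens. A self-contained sketch with toy data items, where ToyDataItem is a stand-in for the library's DataItem:

from collections import Counter


class ToyDataItem:
    """Stand-in for DataItem exposing the extract_edge_tokens contract."""
    def __init__(self, edge_tokens):
        self._edge_tokens = edge_tokens

    def extract_edge_tokens(self):
        # one token per edge, "" when the edge has no "token" attribute
        return self._edge_tokens


instances = [ToyDataItem(["ARG0", "ARG1", ""]), ToyDataItem(["ARG0", "mod"])]

all_edges = Counter()
for instance in instances:
    all_edges.update(instance.extract_edge_tokens())

print(all_edges)  # Counter({'ARG0': 2, 'ARG1': 1, '': 1, 'mod': 1})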
