36
36
from ..modules .utils .tree_utils import Tree
37
37
from ..modules .utils .tree_utils import Vocab as VocabForTree
38
38
from ..modules .utils .tree_utils import VocabForAll
39
- from ..modules .utils .vocab_utils import VocabModel
39
+ from ..modules .utils .vocab_utils import Vocab , VocabModel
40
40
41
41
42
42
class DataItem (object ):
@@ -146,6 +146,16 @@ def extract(self):
146
146
output_tokens = self .tokenizer (self .output_text )
147
147
148
148
return input_tokens , output_tokens
149
+
150
def extract_edge_tokens(self):
    """Return the token attribute of every edge in ``self.graph``.

    Edges that carry no ``"token"`` entry in their attribute dict
    contribute an empty string, so the returned list stays aligned
    with edge indices (one entry per edge, in edge-index order).

    Returns:
        list[str]: edge tokens, ``""`` for edges without a token.
    """
    # NOTE(review): self.graph is presumably a GraphData instance
    # exposing get_edge_num() and edge_attributes — confirm upstream.
    g = self.graph
    # dict.get with a default replaces the explicit membership
    # test + if/else append of the original loop.
    return [
        g.edge_attributes[i].get("token", "")
        for i in range(g.get_edge_num())
    ]
149
159
150
160
151
161
class Text2LabelDataItem (DataItem ):
@@ -312,6 +322,7 @@ def __init__(
312
322
reused_vocab_model = None ,
313
323
nlp_processor_args = None ,
314
324
init_edge_vocab = False ,
325
+ is_hetero = False ,
315
326
** kwargs ,
316
327
):
317
328
"""
@@ -358,6 +369,10 @@ def __init__(
358
369
vocabulary data is located.
359
370
nlp_processor_args: dict, default=None
360
371
It contains the parameter for nlp processor such as ``stanza``.
372
+ init_edge_vocab: bool, default=False
373
+ Whether to initialize the edge vocabulary.
374
+ is_hetero: bool, default=False
375
+ Whether the graph is heterogeneous.
361
376
kwargs
362
377
"""
363
378
super (Dataset , self ).__init__ ()
@@ -387,6 +402,7 @@ def __init__(
387
402
self .topology_subdir = topology_subdir
388
403
self .use_val_for_vocab = use_val_for_vocab
389
404
self .init_edge_vocab = init_edge_vocab
405
+ self .is_hetero = is_hetero
390
406
for k , v in kwargs .items ():
391
407
setattr (self , k , v )
392
408
self .__indices__ = None
@@ -428,8 +444,6 @@ def __init__(
428
444
429
445
vocab = torch .load (self .processed_file_paths ["vocab" ])
430
446
self .vocab_model = vocab
431
- if init_edge_vocab :
432
- self .edge_vocab = torch .load (self .processed_file_paths ["edge_vocab" ])
433
447
434
448
if hasattr (self , "reused_label_model" ):
435
449
self .label_model = LabelModel .build (self .processed_file_paths ["label" ])
@@ -663,6 +677,7 @@ def build_vocab(self):
663
677
target_pretrained_word_emb_name = self .target_pretrained_word_emb_name ,
664
678
target_pretrained_word_emb_url = self .target_pretrained_word_emb_url ,
665
679
word_emb_size = self .word_emb_size ,
680
+ init_edge_vocab = self .init_edge_vocab ,
666
681
)
667
682
self .vocab_model = vocab_model
668
683
@@ -709,41 +724,6 @@ def _process(self):
709
724
self .test = self .build_topology (self .test )
710
725
if "val" in self .__dict__ :
711
726
self .val = self .build_topology (self .val )
712
- # build_edge_vocab and save
713
- if self .init_edge_vocab :
714
- self .edge_vocab = {}
715
- s = set ()
716
- try :
717
- for i in self .train :
718
- graph = i .graph
719
- for edge_idx in range (graph .get_edge_num ()):
720
- if "token" in graph .edge_attributes [edge_idx ]:
721
- edge_token = graph .edge_attributes [edge_idx ]["token" ]
722
- s .add (edge_token )
723
- except Exception as e :
724
- pass
725
- try :
726
- for i in self .test :
727
- graph = i .graph
728
- for edge_idx in range (graph .get_edge_num ()):
729
- if "token" in graph .edge_attributes [edge_idx ]:
730
- edge_token = graph .edge_attributes [edge_idx ]["token" ]
731
- s .add (edge_token )
732
- except Exception as e :
733
- pass
734
- try :
735
- for i in self .val :
736
- graph = i .graph
737
- for edge_idx in range (graph .get_edge_num ()):
738
- if "token" in graph .edge_attributes [edge_idx ]:
739
- edge_token = graph .edge_attributes [edge_idx ]["token" ]
740
- s .add (edge_token )
741
- except Exception as e :
742
- pass
743
- s .add ("" )
744
- self .edge_vocab = {v : k for k , v in enumerate (s )}
745
- print ('edge vocab size:' , len (self .edge_vocab ))
746
- torch .save (self .edge_vocab , self .processed_file_paths ["edge_vocab" ])
747
727
748
728
self .build_vocab ()
749
729
@@ -1116,6 +1096,11 @@ def build_vocab(self):
1116
1096
pretrained_word_emb_cache_dir = self .pretrained_word_emb_cache_dir ,
1117
1097
embedding_dims = self .dec_emb_size ,
1118
1098
)
1099
+ if self .init_edge_vocab :
1100
+ all_edge_words = VocabModel .collect_edge_vocabs (data_for_vocab , self .tokenizer , lower_case = self .lower_case )
1101
+ edge_vocab = Vocab (lower_case = self .lower_case , tokenizer = self .tokenizer )
1102
+ edge_vocab .build_vocab (all_edge_words , max_vocab_size = None , min_vocab_freq = 1 )
1103
+ edge_vocab .randomize_embeddings (self .word_emb_size )
1119
1104
1120
1105
if self .share_vocab :
1121
1106
all_words = Counter ()
@@ -1158,6 +1143,7 @@ def build_vocab(self):
1158
1143
in_word_vocab = src_vocab_model ,
1159
1144
out_word_vocab = tgt_vocab_model ,
1160
1145
share_vocab = src_vocab_model if self .share_vocab else None ,
1146
+ edge_vocab = edge_vocab if self .init_edge_vocab else None ,
1161
1147
)
1162
1148
1163
1149
return self .vocab_model
@@ -1175,6 +1161,18 @@ def vectorization(self, data_items):
1175
1161
token_matrix = torch .tensor (token_matrix , dtype = torch .long )
1176
1162
graph .node_features ["token_id" ] = token_matrix
1177
1163
1164
+ if self .is_hetero :
1165
+ for edge_idx in range (graph .get_edge_num ()):
1166
+ if "token" in graph .edge_attributes [edge_idx ]:
1167
+ edge_token = graph .edge_attributes [edge_idx ]["token" ]
1168
+ else :
1169
+ edge_token = ""
1170
+ edge_token_id = self .edge_vocab [edge_token ]
1171
+ graph .edge_attributes [edge_idx ]["token_id" ] = edge_token_id
1172
+ token_matrix .append ([edge_token_id ])
1173
+ token_matrix = torch .tensor (token_matrix , dtype = torch .long )
1174
+ graph .edge_features ["token_id" ] = token_matrix
1175
+
1178
1176
tgt = item .output_text
1179
1177
tgt_list = self .vocab_model .out_word_vocab .get_symbol_idx_for_list (tgt .split ())
1180
1178
output_tree = Tree .convert_to_tree (
@@ -1183,7 +1181,7 @@ def vectorization(self, data_items):
1183
1181
item .output_tree = output_tree
1184
1182
1185
1183
@classmethod
1186
- def _vectorize_one_dataitem (cls , data_item , vocab_model , use_ie = False ):
1184
+ def _vectorize_one_dataitem (cls , data_item , vocab_model , use_ie = False , is_hetero = False ):
1187
1185
item = deepcopy (data_item )
1188
1186
graph : GraphData = item .graph
1189
1187
token_matrix = []
@@ -1195,6 +1193,21 @@ def _vectorize_one_dataitem(cls, data_item, vocab_model, use_ie=False):
1195
1193
token_matrix = torch .tensor (token_matrix , dtype = torch .long )
1196
1194
graph .node_features ["token_id" ] = token_matrix
1197
1195
1196
+ if is_hetero :
1197
+ if not hasattr (vocab_model , "edge_vocab" ):
1198
+ raise ValueError ("Vocab model must have edge vocab attribute" )
1199
+ token_matrix = []
1200
+ for edge_idx in range (graph .get_edge_num ()):
1201
+ if "token" in graph .edge_attributes [edge_idx ]:
1202
+ edge_token = graph .edge_attributes [edge_idx ]["token" ]
1203
+ else :
1204
+ edge_token = ""
1205
+ edge_token_id = vocab_model .edge_vocab [edge_token ]
1206
+ graph .edge_attributes [edge_idx ]["token_id" ] = edge_token_id
1207
+ token_matrix .append ([edge_token_id ])
1208
+ token_matrix = torch .tensor (token_matrix , dtype = torch .long )
1209
+ graph .edge_features ["token_id" ] = token_matrix
1210
+
1198
1211
if isinstance (item .output_text , str ):
1199
1212
tgt = item .output_text
1200
1213
tgt_list = vocab_model .out_word_vocab .get_symbol_idx_for_list (tgt .split ())
0 commit comments