@@ -136,17 +136,11 @@ def init_nodes(self):
             config.decay_rate,
             staircase=True
         )
-        if config.optimizer == "adamax":
-            print "\nUsing AdaMax Optimizer with lr: %f, decay_steps: %d, decay_rate: %f\n" \
-                % (config.learning_rate, config.decay_steps, config.decay_rate)
 
-            self.optimizer = tf.keras.optimizers.Adamax(learning_rate)
+        print "\nUsing Adam Optimizer with lr: %f, decay_steps: %d, decay_rate: %f\n" \
+            % (config.learning_rate, config.decay_steps, config.decay_rate)
 
-        else:
-            print "\nUsing Adam Optimizer with lr: %f, decay_steps: %d, decay_rate: %f\n" \
-                % (config.learning_rate, config.decay_steps, config.decay_rate)
-
-            self.optimizer = tf.train.AdamOptimizer(learning_rate)
+        self.optimizer = tf.train.AdamOptimizer(learning_rate)
 
         gradients = self.optimizer.compute_gradients(self.loss)
         clipped_gradients = [
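For context, the surviving branch is the standard TF 1.x pattern of an exponentially decayed learning rate fed into Adam. A minimal self-contained sketch of that setup follows; the `Config` class and `global_step` variable here are hypothetical stand-ins for the repo's own objects, not code from this commit:

```python
# Sketch of the post-change optimizer setup, assuming TF 1.x APIs.
import tensorflow as tf

class Config(object):
    learning_rate = 0.001  # initial Adam learning rate (illustrative value)
    decay_steps = 1000     # decay interval, in training steps
    decay_rate = 0.96      # multiplicative decay factor per interval

config = Config()

# Counter incremented once per optimizer step; drives the schedule.
global_step = tf.Variable(0, trainable=False, name="global_step")

# With staircase=True the rate drops in discrete jumps:
# lr = learning_rate * decay_rate ** floor(global_step / decay_steps)
learning_rate = tf.train.exponential_decay(
    config.learning_rate,
    global_step,
    config.decay_steps,
    config.decay_rate,
    staircase=True
)

optimizer = tf.train.AdamOptimizer(learning_rate)
```

Note that the schedule only advances if `global_step` is passed along when gradients are applied, e.g. `optimizer.apply_gradients(clipped_gradients, global_step=global_step)`.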
@@ -183,40 +177,44 @@ def predict(self, sess, dataset, msg):
         answers = []
         ground_answers = []
 
-        with tqdm(dataset, desc=msg) as pbar:
-            for batch in pbar:
-                questions_padded, questions_length = pad_sequences(
-                    batch[:, 0], config.max_question_length
-                )
-                contexts_padded, contexts_length = pad_sequences(
-                    batch[:, 1], config.max_context_length
-                )
+        if msg != None:
+            pbar = tqdm(dataset, desc=msg)
+        else:
+            pbar = dataset
 
-                labels = np.zeros(
-                    (len(batch), config.n_clusters), dtype=np.float32
-                )
-                if config.clustering:
-                    for j, el in enumerate(batch):
-                        labels[j, el[3]] = 1
-                else:
-                    labels[:, 0] = 1
+        for batch in pbar:
+            questions_padded, questions_length = pad_sequences(
+                batch[:, 0], config.max_question_length
+            )
+            contexts_padded, contexts_length = pad_sequences(
+                batch[:, 1], config.max_context_length
+            )
 
-                predictions = sess.run(
-                    self.predictions,
-                    feed_dict={
-                        self.questions_ids: questions_padded,
-                        self.questions_length: questions_length,
-                        self.questions_mask: masks(questions_length, config.max_question_length),
-                        self.contexts_ids: contexts_padded,
-                        self.contexts_length: contexts_length,
-                        self.contexts_mask: masks(contexts_length, config.max_context_length),
-                        self.labels: labels,
-                        self.dropout: 1.0
-                    }
-                )
+            labels = np.zeros(
+                (len(batch), config.n_clusters), dtype=np.float32
+            )
+            if config.clustering:
+                for j, el in enumerate(batch):
+                    labels[j, el[3]] = 1
+            else:
+                labels[:, 0] = 1
+
+            predictions = sess.run(
+                self.predictions,
+                feed_dict={
+                    self.questions_ids: questions_padded,
+                    self.questions_length: questions_length,
+                    self.questions_mask: masks(questions_length, config.max_question_length),
+                    self.contexts_ids: contexts_padded,
+                    self.contexts_length: contexts_length,
+                    self.contexts_mask: masks(contexts_length, config.max_context_length),
+                    self.labels: labels,
+                    self.dropout: 1.0
+                }
+            )
 
-                answers += get_answers(predictions[0], predictions[1])
-                ground_answers += [np.array(el[2]) for el in batch]
+            answers += get_answers(predictions[0], predictions[1])
+            ground_answers += [np.array(el[2]) for el in batch]
 
         return np.array(answers, dtype=np.float32), np.array(ground_answers, dtype=np.float32)
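This hunk replaces the unconditional `with tqdm(...)` wrapper with an optional one, so `predict` can run silently when no description is passed. A minimal sketch of the same pattern in isolation; the function and variable names below are illustrative only, not the repo's:

```python
from tqdm import tqdm

def iter_batches(dataset, msg=None):
    # Wrap the iterable in a tqdm progress bar only when a description
    # is supplied; otherwise iterate it directly, with no console output.
    pbar = tqdm(dataset, desc=msg) if msg is not None else dataset
    for batch in pbar:
        yield batch

# Noisy when a message is given, silent otherwise.
for _ in iter_batches(range(1000), msg="Predicting"):
    pass
for _ in iter_batches(range(1000)):
    pass
```

(`msg is not None` is the idiomatic spelling of the diff's `msg != None` check.)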