Skip to content

Commit 56c384c

Browse files
committed
Added interactive answer prediction
1 parent 3f8ec80 commit 56c384c

10 files changed

+464
-128
lines changed

code/graph.py

+38-40
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,11 @@ def init_nodes(self):
136136
config.decay_rate,
137137
staircase=True
138138
)
139-
if config.optimizer == "adamax":
140-
print "\nUsing AdaMax Optimizer with lr: %f, decay_steps: %d, decay_rate: %f\n" \
141-
% (config.learning_rate, config.decay_steps, config.decay_rate)
142139

143-
self.optimizer = tf.keras.optimizers.Adamax(learning_rate)
140+
print "\nUsing Adam Optimizer with lr: %f, decay_steps: %d, decay_rate: %f\n" \
141+
% (config.learning_rate, config.decay_steps, config.decay_rate)
144142

145-
else:
146-
print "\nUsing Adam Optimizer with lr: %f, decay_steps: %d, decay_rate: %f\n" \
147-
% (config.learning_rate, config.decay_steps, config.decay_rate)
148-
149-
self.optimizer = tf.train.AdamOptimizer(learning_rate)
143+
self.optimizer = tf.train.AdamOptimizer(learning_rate)
150144

151145
gradients = self.optimizer.compute_gradients(self.loss)
152146
clipped_gradients = [
@@ -183,40 +177,44 @@ def predict(self, sess, dataset, msg):
183177
answers = []
184178
ground_answers = []
185179

186-
with tqdm(dataset, desc=msg) as pbar:
187-
for batch in pbar:
188-
questions_padded, questions_length = pad_sequences(
189-
batch[:, 0], config.max_question_length
190-
)
191-
contexts_padded, contexts_length = pad_sequences(
192-
batch[:, 1], config.max_context_length
193-
)
180+
if msg != None:
181+
pbar = tqdm(dataset, desc=msg)
182+
else:
183+
pbar = dataset
194184

195-
labels = np.zeros(
196-
(len(batch), config.n_clusters), dtype=np.float32
197-
)
198-
if config.clustering:
199-
for j, el in enumerate(batch):
200-
labels[j, el[3]] = 1
201-
else:
202-
labels[:, 0] = 1
185+
for batch in pbar:
186+
questions_padded, questions_length = pad_sequences(
187+
batch[:, 0], config.max_question_length
188+
)
189+
contexts_padded, contexts_length = pad_sequences(
190+
batch[:, 1], config.max_context_length
191+
)
203192

204-
predictions = sess.run(
205-
self.predictions,
206-
feed_dict={
207-
self.questions_ids: questions_padded,
208-
self.questions_length: questions_length,
209-
self.questions_mask: masks(questions_length, config.max_question_length),
210-
self.contexts_ids: contexts_padded,
211-
self.contexts_length: contexts_length,
212-
self.contexts_mask: masks(contexts_length, config.max_context_length),
213-
self.labels: labels,
214-
self.dropout: 1.0
215-
}
216-
)
193+
labels = np.zeros(
194+
(len(batch), config.n_clusters), dtype=np.float32
195+
)
196+
if config.clustering:
197+
for j, el in enumerate(batch):
198+
labels[j, el[3]] = 1
199+
else:
200+
labels[:, 0] = 1
201+
202+
predictions = sess.run(
203+
self.predictions,
204+
feed_dict={
205+
self.questions_ids: questions_padded,
206+
self.questions_length: questions_length,
207+
self.questions_mask: masks(questions_length, config.max_question_length),
208+
self.contexts_ids: contexts_padded,
209+
self.contexts_length: contexts_length,
210+
self.contexts_mask: masks(contexts_length, config.max_context_length),
211+
self.labels: labels,
212+
self.dropout: 1.0
213+
}
214+
)
217215

218-
answers += get_answers(predictions[0], predictions[1])
219-
ground_answers += [np.array(el[2]) for el in batch]
216+
answers += get_answers(predictions[0], predictions[1])
217+
ground_answers += [np.array(el[2]) for el in batch]
220218

221219
return np.array(answers, dtype=np.float32), np.array(ground_answers, dtype=np.float32)
222220

code/includes/config.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@
3030

3131
clustering = True
3232

33+
model_name = "k-match-lstm"
34+
3335
data_dir = "data/squad/"
34-
train_dir = "model/k-match-lstm.clustered.weighted"
36+
train_dir = "model/" + model_name + "/"
3537

3638
if not os.path.exists(train_dir):
3739
os.makedirs(train_dir)
3840

39-
plots_dir = "data/plots/"
41+
plots_dir = "data/plots." + model_name + "/"
4042

4143
if not os.path.exists(plots_dir):
4244
os.makedirs(plots_dir)
@@ -48,17 +50,16 @@
4850
embed_path = data_dir + "/glove.npz"
4951

5052
dropout_keep_prob = 0.9
51-
# regularization_constant = 0.001
5253

5354
train_embeddings = False
5455

55-
optimizer = "adamax"
56+
optimizer = "adam"
5657

5758
learning_rate = 0.002
5859
decay_steps = 1000
5960
decay_rate = 0.92
6061

61-
max_gradient = 10.0
62+
max_gradient = 5.0
6263

6364
load_model = True
6465

0 commit comments

Comments (0)