Commit b83e517

committed Jul 28, 2017
almost ready for offline training
1 parent e2f0c84 commit b83e517

File tree

8 files changed, +13066 -49 lines

8 files changed

+13066
-49
lines changed
 

.gitignore (+2 -1)

@@ -119,5 +119,6 @@ ENV/
 
 # End of https://www.gitignore.io/api/linux,python
 
-trees
+#trees
 models
+log.csv
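Note: commenting out "trees" stops ignoring the dataset directory, which is why trees/dev.txt, trees/test.txt and trees/train.txt enter the repository in this commit, while log.csv is newly ignored because fit() in rntn.py (below) now appends a metrics row to it after every epoch.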

main.py (+3 -3)

@@ -40,10 +40,10 @@ def main():
         print("Testing...")
         model = rntn.RNTN.load(args.model)
         test_trees = tr.load_trees(args.dataset)
-        cost, correct, total = model.test(test_trees)
-        accuracy = correct * 100.0 / total
+        cost, result = model.test(test_trees)
+        accuracy = 100.0 * result.trace() / result.sum()
         print("Cost = {:.2f}, Correct = {:.0f} / {:.0f}, Accuracy = {:.2f} %".format(
-            cost, correct, total, accuracy))
+            cost, result.trace(), result.sum(), accuracy))
     else:
         # Initialize the model
         model = rntn.RNTN(
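Note: test() now returns a 5x5 confusion matrix instead of bare correct/total counts, and accuracy is recovered from the matrix because its diagonal holds the correctly classified nodes. A minimal sketch of that arithmetic, with made-up counts:

import numpy as np

# Hypothetical confusion matrix over the 5 sentiment classes:
# result[true_label, predicted_label] counts labelled nodes.
result = np.array([[50,  3,  1,  0,  0],
                   [ 4, 40,  5,  1,  0],
                   [ 1,  6, 70,  5,  1],
                   [ 0,  1,  5, 45,  4],
                   [ 0,  0,  1,  3, 48]])

correct = result.trace()  # diagonal = predictions matching the label
total = result.sum()      # every labelled node, right or wrong
print("Correct = {:.0f} / {:.0f}, Accuracy = {:.2f} %".format(
    correct, total, 100.0 * correct / total))

The payoff over the old counters is that the matrix also exposes per-class confusions, not just the aggregate score.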

rntn.py (+29 -20)

@@ -38,8 +38,9 @@ def fit(self, trees, export_filename='models/RNTN.pickle', verbose=False):
         with open("log.csv", "a", newline='') as csvfile:
             csvwriter = csv.writer(csvfile)
             fieldnames = ["Timestamp", "Vector size", "Learning rate",
-                          "Batch size", "Regularization", "Epoch", "Cost",
-                          "Accuracy"]
+                          "Batch size", "Regularization", "Epoch",
+                          "Train cost", "Train accuracy",
+                          "Test cost", "Test accuracy"]
             if csvfile.tell() == 0:
                 csvwriter.writerow(fieldnames)
 
@@ -53,13 +54,17 @@ def fit(self, trees, export_filename='models/RNTN.pickle', verbose=False):
                 # Save the model
                 self.save(export_filename)
 
-                # Test the model
-                cost, correct, total = self.test(test_trees)
-                accuracy = correct * 100.0 / total
+                # Test the model on train and test set
+                train_cost, train_result = self.test(trees)
+                train_accuracy = 100.0 * train_result.trace() / train_result.sum()
+                test_cost, test_result = self.test(test_trees)
+                test_accuracy = 100.0 * test_result.trace() / test_result.sum()
 
                 # Append data to CSV file
                 row = [datetime.now(), self.dim, self.learning_rate,
-                       self.batch_size, self.reg, epoch, cost, accuracy]
+                       self.batch_size, self.reg, epoch,
+                       train_cost, train_accuracy,
+                       test_cost, test_accuracy]
                 csvwriter.writerow(row)
 
     def test(self, trees):
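Note: the csvfile.tell() == 0 guard above is what keeps log.csv readable across runs: the file is opened in append mode, so tell() is zero only when the file is still empty, and the header row is written exactly once. A standalone sketch of the same pattern (abridged field names, fabricated metrics):

import csv

with open("log.csv", "a", newline='') as csvfile:
    csvwriter = csv.writer(csvfile)
    if csvfile.tell() == 0:  # empty file => first run, write the header
        csvwriter.writerow(["Epoch", "Train cost", "Test cost"])
    csvwriter.writerow([1, 0.52, 0.61])  # one metrics row per epoch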
@@ -145,18 +150,17 @@ def init_params(self):
         self.dbs = np.empty_like(self.bs)
 
     def cost_and_grad(self, trees, test=False):
-        cost, correct, total = 0.0, 0.0, 0.0
+        cost, result = 0.0, np.zeros((5,5))
         self.L, self.V, self.W, self.b, self.Ws, self.bs = self.stack
 
         # Forward propagation
         for tree in trees:
-            _cost, _correct, _total = self.forward_prop(tree)
+            _cost, _result = self.forward_prop(tree)
             cost += _cost
-            correct += _correct
-            total += _total
+            result += _result
 
         if test:
-            return cost / len(trees), correct, total
+            return cost / len(trees), result
 
         # Initialize gradients
         self.dL = collections.defaultdict(lambda: np.zeros((self.dim,)))

@@ -191,7 +195,8 @@ def cost_and_grad(self, trees, test=False):
         return cost, grad
 
     def forward_prop(self, tree):
-        cost, correct, total = 0.0, 0.0, 0.0
+        cost = 0.0
+        result = np.zeros((5,5))
 
         if tr.isleaf(tree):
             # output = word vector

@@ -202,11 +207,10 @@ def forward_prop(self, tree):
             tree.fprop = True
         else:
             # calculate output of child nodes
-            lcost, lcorrect, ltotal = self.forward_prop(tree[0])
-            rcost, rcorrect, rtotal = self.forward_prop(tree[1])
+            lcost, lresult = self.forward_prop(tree[0])
+            rcost, rresult = self.forward_prop(tree[1])
             cost += lcost + rcost
-            correct += lcorrect + rcorrect
-            total += ltotal + rtotal
+            result += lresult + rresult
 
             # compute output
             lr = np.hstack([tree[0].vector, tree[1].vector])

@@ -224,10 +228,11 @@ def forward_prop(self, tree):
 
         # cost
         cost -= np.log(tree.output[int(tree.label())])
-        correct += (np.argmax(tree.output) == int(tree.label()))
-        total += 1
+        true_label = int(tree.label())
+        predicted_label = np.argmax(tree.output)
+        result[true_label, predicted_label] += 1
 
-        return cost, correct, total
+        return cost, result
 
     def back_prop(self, tree, error=None):
         # clear nodes

@@ -245,7 +250,11 @@ def back_prop(self, tree, error=None):
 
         # leaf node => update word vectors
         if tr.isleaf(tree):
-            self.dL[self.word_map[tree[0]]] += deltas
+            try:
+                index = self.word_map[tree[0]]
+            except KeyError:
+                index = self.word_map[tr.UNK]
+            self.dL[index] += deltas
             return
 
         # Hidden gradients
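Note: the try/except added to back_prop routes gradients for out-of-vocabulary words to the shared tr.UNK entry instead of raising KeyError. A toy, self-contained version of the lookup (the word map is fabricated and 'UNK' stands in for tr.UNK):

word_map = {'UNK': 0, 'not': 1, 'very': 2, 'good': 3}

def word_index(word):
    try:
        return word_map[word]
    except KeyError:  # unseen word -> shared UNK slot
        return word_map['UNK']

print(word_index('good'))   # 3
print(word_index('movie'))  # 0, falls back to UNK

word_map.get(word, word_map['UNK']) would be an equivalent one-liner.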

test.py (+2 -8)

@@ -1,7 +1,5 @@
 #!/bin/env python3
 
-import unittest
-
 import tree as tr
 
 
@@ -21,17 +19,13 @@ def f(model, text):
     >>> model = rntn.RNTN.load('models/RNTN.pickle')
     >>> f(model, "not very good")
          1
-    -----|-----
+     ____|____
     |         4
-    |      ---|---
+    |       __|__
     2      2     3
     |      |     |
    not   very  good
 
     """
     for tree in tr.parse(text):
         model.predict(tree).pretty_print()
-
-
-if __name__ == '__main__':
-    unittest.main()
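Note: with unittest.main() removed and nothing in its place, the doctest above is presumably meant to be run externally, e.g. with python3 -m doctest test.py -v. The expected tree drawing also switches from dashes to underscores, which suggests the old docstring simply did not match what pretty_print() actually emits.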

train.sh (+74 -17)

@@ -1,22 +1,79 @@
 #!/bin/sh
 
-set -x
-
-dim=25
-epochs=30
-learning_rate=1e-2
-batch_size=30
-optimizer="adagrad"
-reg=1e-6
+# The default values
+DEFAULT_DIM=25
+DEFAULT_EPOCH=10
+DEFAULT_LEARNING_RATE=1e-1
+DEFAULT_BATCH_SIZE=30
+DEFAULT_REG=1e-6
+
+# Values to test
+DIM_LIST=( 10 20 25 30 40 50 )
+LEARNING_RATE_LIST=( 1e-4 1e-3 1e-2 1e-1 1 )
+BATCH_SIZE_LIST=( 1 10 30 50 70 100 )
+REG_LIST=( 1e-6 1e-4 1e-2 0 10 )
+
+optimizer="adagrad"  # This one is held constant
+
+# Tune vector size
+#------------------
+
+epochs=$DEFAULT_EPOCH
+learning_rate=$DEFAULT_LEARNING_RATE
+batch_size=$DEFAULT_BATCH_SIZE
+reg=$DEFAULT_REG
 datetime=$(date +"%Y%m%d%H%M")
-dataset="train"
+for dim in "${DIM_LIST[@]}"; do
+    outfile="models/RNTN_D${dim}_E${epochs}_B${batch_size}_L${learning_rate}_R${reg}_${optimizer}_${datetime}.pickle"
+    set -x
+    python3 main.py \
+        --dim=${dim} \
+        --epochs=${epochs} \
+        --learning-rate=${learning_rate} \
+        --batch-size=${batch_size} \
+        --reg=${reg} \
+        --model=${outfile}
+    set +x
+done
 
-outfile="models/RNTN_D${dim}_E${epochs}_B${batch_size}_L${learning_rate}_R${reg}_${optimizer}_${datetime}.pickle"
+# Tune batch size
+#----------------
 
-python3 main.py \
-    --dim=${dim} \
-    --epochs=${epochs} \
-    --learning-rate=${learning_rate} \
-    --batch-size=${batch_size} \
-    --dataset=${dataset} \
-    --model=${outfile}
+epochs=$DEFAULT_EPOCH
+learning_rate=$DEFAULT_LEARNING_RATE
+dim=$DEFAULT_DIM
+reg=$DEFAULT_REG
+datetime=$(date +"%Y%m%d%H%M")
+for batch_size in "${BATCH_SIZE_LIST[@]}"; do
+    outfile="models/RNTN_D${dim}_E${epochs}_B${batch_size}_L${learning_rate}_R${reg}_${optimizer}_${datetime}.pickle"
+    set -x
+    python3 main.py \
+        --dim=${dim} \
+        --epochs=${epochs} \
+        --learning-rate=${learning_rate} \
+        --batch-size=${batch_size} \
+        --reg=${reg} \
+        --model=${outfile}
+    set +x
+done
+
+# Tune regularization parameter
+#-------------------------------
+
+epochs=$DEFAULT_EPOCH
+learning_rate=$DEFAULT_LEARNING_RATE
+dim=$DEFAULT_DIM
+batch_size=$DEFAULT_BATCH_SIZE
+datetime=$(date +"%Y%m%d%H%M")
+for reg in "${REG_LIST[@]}"; do
+    outfile="models/RNTN_D${dim}_E${epochs}_B${batch_size}_L${learning_rate}_R${reg}_${optimizer}_${datetime}.pickle"
+    set -x
+    python3 main.py \
+        --dim=${dim} \
+        --epochs=${epochs} \
+        --learning-rate=${learning_rate} \
+        --batch-size=${batch_size} \
+        --reg=${reg} \
+        --model=${outfile}
+    set +x
+done
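Note: the arrays (DIM_LIST=( ... )) and the "${DIM_LIST[@]}" expansions introduced here are bash features, while the shebang is still #!/bin/sh; on systems where sh is dash these loops will not parse, so running the script as "bash train.sh" (or switching the shebang to #!/bin/bash) is the safer option. LEARNING_RATE_LIST is also defined but has no tuning loop yet, which fits the commit message's "almost ready".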

trees/dev.txt (+1,101 lines; large diff not rendered)
trees/test.txt (+2,210 lines; large diff not rendered)
trees/train.txt (+9,645 lines; large diff not rendered)
