#!/usr/bin/python3
import json
import flask
import random
import os
import ankura
import time
import pickle
from tqdm import tqdm
import sys
import tempfile
import threading
import argparse
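
# Help formatter that both shows argument defaults and preserves the epilog's line breaks.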
class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
    pass

parser = argparse.ArgumentParser(
    description='Used for hosting tbuie with a given dataset',
    epilog=('See https://github.com/byu-aml-lab/tbuie\n' +
            ' and https://github.com/byu-aml-lab/ankura/tree/ankura2/ankura\n' +
            ' for source and dependencies\n \n'),
    formatter_class=CustomFormatter)
parser.add_argument('dataset', metavar='dataset',
                    choices=['newsgroups', 'yelp', 'tripadvisor', 'amazon'],
                    help='The name of a dataset to use in this instance of tbuie')
parser.add_argument('port', nargs='?', default=5000, type=int,
                    help='Port to be used in hosting the webpage')
args = parser.parse_args()
dataset_name = args.dataset
port = args.port
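
# Flask app state and topic-model hyperparameters; user_data collects one
# (anchor_tokens, anchor_vectors, accuracy) tuple per /topics request.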
app = flask.Flask(__name__, static_url_path='')
user_data = list()
dev_size = 2500
number_of_topics = 20
label_weight = 1
smoothing = 0
prior_attr_name = 'lambda'  # Attribute for the prior probabilities (probability of each label)
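
# Select the ankura corpus loader and the metadata attribute used as the document label.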
if dataset_name == 'newsgroups':
    attr_name = 'coarse_newsgroup'
    corpus = ankura.corpus.newsgroups()
elif dataset_name == 'yelp':
    attr_name = 'binary_rating'
    corpus = ankura.corpus.yelp()
elif dataset_name == 'tripadvisor':
    attr_name = 'label'
    corpus = ankura.corpus.tripadvisor()
elif dataset_name == 'amazon':
    attr_name = 'binary_rating'
    corpus = ankura.corpus.amazon()
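
# Offline helper for re-scoring saved anchor sets; not called by the Flask routes below.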
def calculate_user_data_accuracy(user_data, Q, test_corpus, train_dev_corpus, attr_name):
    for i, data in enumerate(user_data):
        anchor_tokens, anchor_vectors, accuracy = data
        lr_accuracy = ankura.validate.anchor_accuracy(Q, anchor_vectors, test_corpus, train_dev_corpus, attr_name)
        print('Instance', i, 'Free Classifier Accuracy:', accuracy, 'Logistic Regression Accuracy:', lr_accuracy)

@ankura.util.pickle_cache(dataset_name + '.pickle')
def load_data():
    print('Splitting train/dev and test...')
    # 80/20 split into test and train
    split = ankura.pipeline.train_test_split(corpus, return_ids=True)
    (train_dev_ids, train_dev_corpus), (test_ids, test_corpus) = split
    train_dev_size = len(train_dev_ids)
    print(f' train/dev size: {train_dev_size}')
    print(f' test size: {len(test_ids)}')
    train_size = train_dev_size - dev_size

    print('Splitting train and dev...')
    # Second split to give train and dev sets
    split = ankura.pipeline.train_test_split(train_dev_corpus, num_train=train_size,
                                             num_test=dev_size, remove_testonly_words=False,
                                             return_ids=True)
    (train_ids, train_corpus), (dev_ids, dev_corpus) = split
    print(f' train size: {train_size}')
    print(f' dev size: {dev_size}')

    Q, labels = ankura.anchor.build_labeled_cooccurrence(train_dev_corpus, attr_name,
                                                         range(len(train_dev_corpus.documents)),  # All are labeled
                                                         label_weight=label_weight, smoothing=smoothing)

    gs_anchor_indices = ankura.anchor.gram_schmidt_anchors(train_dev_corpus, Q, k=number_of_topics, return_indices=True)
    gs_anchor_vectors = Q[gs_anchor_indices]
    gs_anchor_tokens = [[train_dev_corpus.vocabulary[index]] for index in gs_anchor_indices]

    # This is memory inefficient, since we never use train_corpus.
    return (Q, labels, train_dev_ids, train_dev_corpus,
            train_ids, train_corpus, dev_corpus, dev_ids,
            test_ids, test_corpus, gs_anchor_vectors,
            gs_anchor_indices, gs_anchor_tokens)

# Load the data (will load from pickle if it can)
(Q, labels, train_dev_ids, train_dev_corpus,
 train_ids, train_corpus, dev_corpus, dev_ids,
 test_ids, test_corpus, gs_anchor_vectors,
 gs_anchor_indices, gs_anchor_tokens) = load_data()
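
# Flask routes: serve the static single-page UI and the endpoints it calls.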
@app.route('/')
def serve_itm():
    return app.send_static_file('index.html')

@app.route('/vocab')
def get_vocab():
    return flask.jsonify(vocab=train_dev_corpus.vocabulary)
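
# Saves the anchors and accuracies gathered during this run to FinalAnchors/<dataset_name>.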
@app.route('/finished', methods=['GET', 'POST'])
def finish():
    directory = os.path.join('FinalAnchors', dataset_name)
    try:
        os.makedirs(directory)
    except FileExistsError:
        pass
    # Write to a uniquely named temp file so concurrent sessions never clobber each other.
    with tempfile.NamedTemporaryFile(mode='wb',
                                     delete=False,
                                     prefix=dataset_name,
                                     suffix='.pickle',
                                     dir=directory) as outfile:
        pickle.dump(user_data, outfile)
    return 'OK'
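
# Main interactive endpoint: recover topics for the requested anchors (or the
# Gram-Schmidt defaults) and report free-classifier accuracy on the dev set.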
@app.route('/topics')
def topic_request():
    raw_anchors = flask.request.args.get('anchors')

    start = time.time()
    if raw_anchors is None:
        anchor_tokens, anchor_vectors = gs_anchor_tokens, gs_anchor_vectors
    else:
        anchor_tokens = json.loads(raw_anchors)
        anchor_vectors = ankura.anchor.tandem_anchors(anchor_tokens, Q,
                                                      train_dev_corpus, epsilon=1e-15)
    print('***Time - tandem_anchors:', time.time()-start)

    start = time.time()
    C, topics = ankura.anchor.recover_topics(Q, anchor_vectors, epsilon=1e-5, get_c=True)
    print('***Time - recover_topics:', time.time()-start)

    start = time.time()
    topic_summary = ankura.topic.topic_summary(topics[:len(train_dev_corpus.vocabulary)], train_dev_corpus)
    print('***Time - topic_summary:', time.time()-start)

    start = time.time()
    classifier = ankura.topic.free_classifier_dream(train_dev_corpus, attr_name,
                                                    labeled_docs=set(train_ids), topics=topics,
                                                    C=C, labels=labels,
                                                    prior_attr_name=prior_attr_name)
    print('***Time - Get Classifier:', time.time()-start)

    contingency = ankura.validate.Contingency()

    start = time.time()
    for doc in dev_corpus.documents:
        gold = doc.metadata[attr_name]
        pred = classifier(doc)
        contingency[gold, pred] += 1
    print('***Time - Classify:', time.time()-start)
    print('***Accuracy:', contingency.accuracy())

    user_data.append((anchor_tokens, anchor_vectors, contingency.accuracy()))

    return flask.jsonify(anchors=anchor_tokens,
                         topics=topic_summary,
                         accuracy=contingency.accuracy())
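
# Run the Flask development server, listening on all interfaces on the chosen port.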
if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=port)