EvalHitsAtK.py
"""
Copyright (C) 2017-2018 University of Massachusetts Amherst.
This file is part of "learned-string-alignments"
http://github.com/iesl/learned-string-alignments
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import sys
import numpy as np
def eval_hits_at_k(list_of_list_of_labels,
list_of_list_of_scores,
k=10,
randomize=True,
oracle=False,
):
"""Compute Hits at K
Given a two lists with one element per test example compute the
mean average precision score.
The i^th element of each list is an array of scores or labels corresponding
to the i^th training example.
All scores are SIMILARITIES.
:param list_of_list_of_labels: Binary relevance labels. One list per example.
:param list_of_list_of_scores: Predicted relevance scores. One list per example.
:param k: the number of elements to consider
:param randomize: whether to randomize the ordering
:param oracle: break ties using the labels
:return: the mean average precision
"""
    np.random.seed(19)
    assert len(list_of_list_of_labels) == len(list_of_list_of_scores)
    aps = []
    for i in range(len(list_of_list_of_labels)):
        if randomize:
            # Shuffle each example so that score ties are broken at random
            # rather than by the original ordering of the input file.
            perm = np.random.permutation(len(list_of_list_of_labels[i]))
            list_of_list_of_labels[i] = list(np.asarray(list_of_list_of_labels[i])[perm])
            list_of_list_of_scores[i] = list(np.asarray(list_of_list_of_scores[i])[perm])
        if oracle:
            # Sort on (score, label) pairs so that ties are broken in favor
            # of positive labels (an optimistic, oracle ranking).
            zpd = zip(list_of_list_of_scores[i], list_of_list_of_labels[i])
            sorted_zpd = sorted(zpd, reverse=True)
            list_of_list_of_labels[i] = [x[1] for x in sorted_zpd]
            list_of_list_of_scores[i] = [x[0] for x in sorted_zpd]
        else:
            zpd = zip(list_of_list_of_scores[i], list_of_list_of_labels[i])
            sorted_zpd = sorted(zpd, key=lambda x: x[0], reverse=True)
            list_of_list_of_labels[i] = [x[1] for x in sorted_zpd]
            list_of_list_of_scores[i] = [x[0] for x in sorted_zpd]
        labels_topk = list_of_list_of_labels[i][0:k]
        # Only examples with at least one positive label contribute to the mean.
        if sum(list_of_list_of_labels[i]) > 0:
            hits_at_k = sum(labels_topk) * 1.0 / min(k, sum(list_of_list_of_labels[i]))
            aps.append(hits_at_k)
    return sum(aps) / len(aps)
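

# Hedged usage sketch (added for illustration, not part of the original
# module): with the toy inputs below and k=2, the first query ranks its
# candidates by score as labels [1, 0, 0, 1] and scores Hits@2 = 1/2, the
# second scores 1/1, so the mean should be 0.75. Kept as comments so that
# importing the module stays side-effect free; the _toy_* names are
# hypothetical.
#
#   _toy_labels = [[1, 0, 0, 1], [0, 1]]
#   _toy_scores = [[0.9, 0.8, 0.3, 0.1], [0.2, 0.7]]
#   eval_hits_at_k(_toy_labels, _toy_scores, k=2)  # -> 0.75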


def load(prediction_filename, true_label_filename):
    """Load the labels and scores for Hits@K evaluation.

    Both files are tab-separated and must be aligned line by line. As the
    indexing below implies, column 0 is the query, column 1 the example,
    and column 2 holds the predicted score (prediction file) or the binary
    relevance label (label file).

    :param prediction_filename: File of predicted scores to load.
    :param true_label_filename: File of gold binary labels to load.
    :return: list_of_list_of_labels, list_of_list_of_scores
    """
    result_labels = []
    result_scores = []
    current_block_name = ""
    current_block_scores = []
    current_block_labels = []
    with open(prediction_filename, 'r') as prediction, open(true_label_filename, 'r') as label:
        prediction_lines = prediction.readlines()
        label_lines = label.readlines()
        for i in range(len(prediction_lines)):
            pred_line = prediction_lines[i]
            label_line = label_lines[i]
            pred_split = pred_line.strip().split("\t")
            label_split = label_line.strip().split("\t")
            block_name = pred_split[0]
            example_label = int(label_split[2])
            example_score = float(pred_split[2])
            # A new query name starts a new block of candidates.
            if block_name != current_block_name and current_block_name != "":
                result_labels.append(current_block_labels)
                result_scores.append(current_block_scores)
                current_block_labels = []
                current_block_scores = []
            current_block_labels.append(example_label)
            current_block_scores.append(example_score)
            current_block_name = block_name
    # Flush the final block.
    result_labels.append(current_block_labels)
    result_scores.append(current_block_scores)
    return result_labels, result_scores
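

# Illustrative input layout (hypothetical file names and contents, tabs shown
# as <TAB>), matching the indexing used in load() above. Rows for the same
# query must be contiguous and the two files must be aligned line by line:
#
#   predictions.tsv:  query1<TAB>candidateA<TAB>0.87
#                     query1<TAB>candidateB<TAB>0.12
#   labels.tsv:       query1<TAB>candidateA<TAB>1
#                     query1<TAB>candidateB<TAB>0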


def eval_hits_at_k_file(prediction_filename, true_label_filename, k=2, oracle=False):
    list_of_list_of_labels, list_of_list_of_scores = load(prediction_filename, true_label_filename)
    return eval_hits_at_k(list_of_list_of_labels, list_of_list_of_scores, k=k, oracle=oracle)
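

# Example call from another script (hypothetical file names), as an
# alternative to the command-line entry point below:
#
#   eval_hits_at_k_file("predictions.tsv", "labels.tsv", k=10)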
if __name__ == "__main__":
"""
Usage: filename [k=1] [oracle=False]
"""
filename = sys.argv[1]
k = int(sys.argv[2]) if len(sys.argv) > 2 else 1
oracle = sys.argv[3] == "True" if len(sys.argv) > 3 else False
# print("{}\t{}\t{}\t{}".format(filename,k,oracle, eval_hits_at_k_file(filename,k=k,oracle=oracle)))