-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path: utils.py
More file actions
83 lines (62 loc) · 2.43 KB
/
utils.py
File metadata and controls
83 lines (62 loc) · 2.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Helper functions for evaluation."""
import collections
import re
import string
def normalize_answer(s):
    """Normalize an answer string for comparison.

    Extended SQuAD-style normalization: lowercase, turn dashes into
    spaces, strip all other punctuation, drop the articles a/an/the,
    and collapse runs of whitespace.

    Arguments:
        s: ``str`` String to normalize.
    """
    text = s.lower()
    # Dashes become spaces (so "quick-brown" splits into two tokens);
    # every other punctuation character is deleted outright.
    punct_table = str.maketrans("-", " ", string.punctuation.replace("-", ""))
    text = text.translate(punct_table)
    # Remove English articles as whole words only.
    text = re.sub(r"\b(a|an|the)\b", " ", text, flags=re.UNICODE)
    # Collapse any whitespace runs left behind by the removals above.
    return " ".join(text.split())
def get_tokens(s):
    """Return the normalized token list for ``s``.

    A falsy input (empty string or ``None``) yields an empty list.
    """
    return normalize_answer(s).split() if s else []
def em(ans, pred):
    """Exact-match metric: 1 if the normalized strings agree, else 0."""
    if normalize_answer(ans) == normalize_answer(pred):
        return 1
    return 0
def f1(ans, pred):
    """Token-level F1 between a gold answer and a prediction.

    Tokens are compared as multisets of normalized tokens; returns the
    integer 0 when there is no overlap at all, otherwise a float in (0, 1].
    """
    gold = collections.Counter(get_tokens(ans))
    guess = collections.Counter(get_tokens(pred))
    overlap = sum((gold & guess).values())
    if overlap == 0:
        return 0
    precision = overlap / sum(guess.values())
    recall = overlap / sum(gold.values())
    return (2 * precision * recall) / (precision + recall)
def get_subset_scores(amber_sets, raw_metrics, head_subset: bool):
    """Average per-query metric scores over the head or tail subset.

    Arguments:
        amber_sets: Iterable of dicts, each with a ``"qids"`` mapping whose
            values carry an ``"is_head"`` flag and a list of ``"queries"``
            (each query dict has an ``"id"``).
        raw_metrics: Mapping from metric name to {query_id: raw score}.
        head_subset: Select queries whose ``is_head`` equals this value.

    Returns:
        Dict mapping metric name to its mean score over the subset,
        scaled to a percentage.
    """
    collected = collections.defaultdict(list)
    for amber_set in amber_sets:
        for qid_info in amber_set["qids"].values():
            if qid_info["is_head"] != head_subset:
                continue
            for query in qid_info["queries"]:
                query_id = query["id"]
                for name, scores in raw_metrics.items():
                    # Entity confusion is not computed over every query,
                    # so only collect scores that actually exist.
                    if query_id in scores:
                        collected[name].append(scores[query_id])
    return {
        name: 100 * sum(values) / len(values)
        for name, values in collected.items()
    }