-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmwp.py
115 lines (94 loc) · 4.38 KB
/
mwp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import re
from prompts.base import ExampleTemplate
# Instruction sentence for GSM8K prompts; presumably prepended to the
# few-shot prompt by the caller (not visible here).
gsm8k_prefix: str = (
    "Answer the following question through careful, concise step-by-step reasoning."
)
# Final-answer marker used by GSM8K reference solutions ("#### 42").
_GSM8K_ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# Fallback answer format for chat-model ("turbo") completions: "Answer: \boxed{42}".
_GSM8K_TURBO_ANS_RE = re.compile(r"Answer: \\boxed{(\-?[0-9\.\,]+)}")
# Sentinel returned when no answer can be extracted.
_GSM8K_INVALID_ANS = "[invalid]"


class GSM8KExampleTemplate(ExampleTemplate):
    """Prompt template for GSM8K-style math word problems.

    Examples are rendered as a ``Question:`` line followed by a ``Solution:``
    line. The final numeric answer is read from a ``#### <number>`` marker,
    or, failing that, from an ``Answer: \\boxed{<number>}`` line.
    """

    input_variables: list[str] = ['question', 'answer']
    templates: dict[str, str] = dict(
        train="Question: {source}\nSolution: {target}",
        test="Question: {source}\nSolution: ",
    )

    def format(self, test=False, embedding=False, **kwargs):
        """Render one example as prompt text.

        Args:
            test: use the ``test`` template (solution left blank) instead of
                the ``train`` template.
            embedding: return only the stripped question text.
            **kwargs: example fields, keyed by ``input_variables``.
        """
        source = self.get_source(**kwargs)
        if embedding:
            # Only the question is returned here, so don't require the target.
            return source
        target = self.get_target(**kwargs)
        template = self.templates['test'] if test else self.templates['train']
        return template.format(source=source, target=target)

    def prepare_for_turbo(self, string):
        """Split a formatted example into ``(question, solution)`` strings.

        Text before the first ``"Solution: "`` marker is the question, with
        ``"Question: "`` labels removed. Text after the marker is the
        solution. If the marker is missing, the whole string is the question
        and the solution is ``""``.
        """
        # partition() yields ("...", "", "") when the marker is absent, so
        # there is no -1 index to mis-slice on.
        question, _, solution = string.partition("Solution: ")
        # Remove the label before stripping, so an empty question does not
        # leave a dangling "Question:" behind.
        question = question.replace("Question: ", "").strip()
        return question, solution.strip()

    def get_source(self, **kwargs):
        """Return the stripped question field."""
        return kwargs[self.input_variables[0]].strip()

    def get_target(self, **kwargs):
        """Return the stripped reference-solution field."""
        return kwargs[self.input_variables[1]].strip()

    def extract_answer(self, completion):
        """Return the first numeric answer found in *completion*, commas removed.

        The ``#### <number>`` marker is tried first, then the
        ``Answer: \\boxed{<number>}`` form. Returns ``"[invalid]"`` when
        neither matches.
        """
        for pattern in (_GSM8K_ANS_RE, _GSM8K_TURBO_ANS_RE):
            match = pattern.search(completion)
            if match:
                return match.group(1).strip().replace(",", "")
        return _GSM8K_INVALID_ANS

    def parse_output(self, lm_output: str, **kwargs):
        """Extract the predicted numeric answer from raw model output."""
        # based on https://github.com/openai/grade-school-math/blob/3101c7d5072418e28b9008a6636bde82a006892c/grade_school_math/dataset.py#L28
        return self.extract_answer(lm_output)

    def check_output(self, pred, target, **kwargs):
        """Score a parsed prediction against the reference solution text.

        Args:
            pred: answer string as returned by ``parse_output``.
            target: full reference solution containing a ``#### <number>``
                marker.

        Returns:
            ``dict(accuracy=100)`` if the answers match, else
            ``dict(accuracy=0)``.
        """
        answer = self.extract_answer(target)
        # An unparseable reference must never count as correct. Without this
        # guard, "[invalid]" == "[invalid]" would score a failed prediction
        # as a hit.
        is_correct = answer != _GSM8K_INVALID_ANS and pred == answer
        return dict(accuracy=is_correct * 100)
# Instruction sentence for AQuA prompts; presumably prepended to the
# few-shot prompt by the caller (not visible here).
# Taken from cot/fragments.json.
aqua_prefix: str = (
    "Answer the following question through careful, concise step-by-step reasoning."
)
class AquaExampleTemplate(ExampleTemplate):
    """Prompt template for multiple-choice (options A-E) math questions.

    Presumably the AQuA dataset, going by the name. Training examples show
    the question, the choices, a rationale and ``Answer: <letter>``. Test
    examples end after ``Rationale: `` and leave the rest to the model. The
    predicted letter is read from an ``Answer: <A-E>`` line.
    """

    input_variables: list[str] = ['question', 'options', 'rationale', 'correct']
    templates: dict[str, str] = dict(
        train="Question: {source}\nChoices: {choices}\nRationale: {rationale}\nAnswer: {target}",
        test="Question: {source}\nChoices: {choices}\nRationale: ",
    )

    def format(self, test=False, embedding=False, **kwargs):
        """Render one example as prompt text.

        Args:
            test: use the ``test`` template (stops at ``Rationale: ``) instead
                of the ``train`` template.
            embedding: return only the stripped question text.
            **kwargs: example fields, keyed by ``input_variables``.
        """
        source = self.get_source(**kwargs)
        if embedding:
            # Only the question is returned here, so don't require the other
            # fields.
            return source
        target = self.get_target(**kwargs)
        template = self.templates['test'] if test else self.templates['train']
        return template.format(
            source=source,
            target=target,
            choices=self.get_choices(**kwargs),
            rationale=self.get_rationale(**kwargs),
        )

    def get_source(self, **kwargs):
        """Return the stripped question field."""
        return kwargs[self.input_variables[0]].strip()

    def get_choices(self, **kwargs):
        """Join the options field into one comma-separated string."""
        # assumes the options field is an iterable of strings -- TODO confirm
        return ', '.join(kwargs[self.input_variables[1]])

    def get_rationale(self, **kwargs):
        """Return the stripped rationale field."""
        return kwargs[self.input_variables[2]].strip()

    def get_target(self, **kwargs):
        """Return the stripped correct-option field."""
        return kwargs[self.input_variables[3]].strip()

    def extract_answer(self, completion):
        """Return the first ``Answer: <A-E>`` letter in *completion*.

        Returns ``"[invalid]"`` when there is none.
        """
        match = re.search(r"Answer: ([A-E])", completion)
        return match.group(1) if match else "[invalid]"

    def parse_output(self, lm_output: str, **kwargs):
        """Extract the predicted option letter from raw model output."""
        return self.extract_answer(lm_output)

    def check_output(self, pred, target, **kwargs):
        """Score a parsed option letter against the reference option.

        Returns ``dict(accuracy=100)`` on an exact match, else
        ``dict(accuracy=0)``.
        """
        if isinstance(target, str):
            # get_target() strips this field when building prompts. Normalise
            # the same way here so stray whitespace can't cause a false
            # negative.
            target = target.strip()
        is_correct = pred == target
        return dict(accuracy=is_correct * 100)