-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutil.py
executable file
·135 lines (116 loc) · 4.04 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#!/usr/bin/python3.8
import dataclasses
import json
import functools
import os
import re
import string
import time
import anthropic
import groq
import mistralai.client
import mistralai.models
# infra - checkpoint while generating results
@dataclasses.dataclass
class CheckPointer:
    """Collect results and periodically dump the new ones to JSONL files.

    Meant to be used as a context manager: entering starts the checkpoint
    clock, and leaving forces a final write of anything not yet saved.
    Every checkpoint goes to its own file, named
    '<filestem>.<lo>-<hi>.jsonl', which holds results[lo:hi] with one
    JSON document per line.
    """
    # checkpoint once this many results are pending (0 disables the check)
    k_interval: int = 0
    # checkpoint once this many seconds have elapsed (0 disables the check)
    time_interval: float = 60.0
    # prefix of every checkpoint file name
    filestem: str = 'checkpt'
    # value of k when the last checkpoint was written
    last_chkpt_k: int = 0
    # time.time() of the last checkpoint (reset in __enter__)
    last_chkpt_time: float = 0.0
    # running count of results; incremented by every append()
    k: int = 0
    # all results appended so far, in order
    results: list = dataclasses.field(default_factory=list)

    def __enter__(self):
        # Start the clock here so the time interval is measured from the
        # moment the with-block begins.
        self.last_chkpt_time = time.time()
        print(f'initialize Checkpointing: {self.k=}')
        return self

    def __exit__(self, _type, _value, _traceback):
        # Save whatever is still pending, whether or not the block raised.
        self.flush(forced=True)

    def _num_ready_to_flush(self, forced):
        """Report how many results are pending and whether to write them now.

        Returns a pair (pending, flush_now).  flush_now is True when the
        count interval or the time interval has been reached; otherwise
        it is just the caller's `forced` value.
        """
        pending = self.k - self.last_chkpt_k
        if self.k_interval and pending >= self.k_interval:
            return pending, True
        if self.time_interval:
            deadline = self.last_chkpt_time + self.time_interval
            if time.time() >= deadline:
                return pending, True
        return pending, forced

    def flush(self, forced=False):
        """Write the pending results to a new checkpoint file, if due.

        Nothing is written when there are no pending results, or when no
        interval has been reached and `forced` is false.
        """
        pending, due = self._num_ready_to_flush(forced)
        if pending == 0 or not due:
            return
        since_last = time.time() - self.last_chkpt_time
        start, stop = self.k - pending, self.k
        path = f'{self.filestem}.{start:04d}-{stop:04d}.jsonl'
        with open(path, 'w') as out:
            out.writelines(
                json.dumps(item) + '\n' for item in self.results[start:stop])
        print(f'write {path} with {pending} outputs in {since_last:.2f} sec')
        # Everything up to k is now on disk; restart both interval counters.
        self.last_chkpt_k = self.k
        self.last_chkpt_time = time.time()

    def append(self, result):
        """Record one more result and checkpoint if an interval is reached."""
        self.results.append(result)
        self.k += 1
        self.flush(forced=False)
# send a prompt to an LLM through one of the supported hosted services
def llm(prompt, service='groq', model=None):
    """Send `prompt` as a single user message and return the reply text.

    Args:
      prompt: text of the user message.
      service: which backend to call: 'groq', 'mistral' or 'anthropic'.
      model: model name to request; None picks a per-service default.

    Returns:
      The text of the first choice (groq, mistral) or of the first
      content block (anthropic) in the response.

    Raises:
      ValueError: if `service` is not one of the names above.

    The API key for each service is read from the environment
    (GROQ_API_KEY, MISTRAL_API_KEY, ANTHROPIC_API_KEY), and a new client
    is created on every call.
    """
    if service == 'groq':
        if model is None:
            model = 'mixtral-8x7b-32768'
        groq_client = groq.Groq(api_key=os.environ.get('GROQ_API_KEY'))
        reply = groq_client.chat.completions.create(
            messages=[dict(role='user', content=prompt)],
            model=model)
        return reply.choices[0].message.content

    if service == 'mistral':
        if model is None:
            model = 'open-mixtral-8x7b'
        mistral_client = mistralai.client.MistralClient(
            api_key=os.environ.get('MISTRAL_API_KEY'))
        user_turn = mistralai.models.chat_completion.ChatMessage(
            role='user', content=prompt)
        reply = mistral_client.chat(messages=[user_turn], model=model)
        return reply.choices[0].message.content

    if service == 'anthropic':
        if model is None:
            model = 'claude-3-haiku-20240307'
        anthropic_client = anthropic.Anthropic(
            api_key=os.environ.get('ANTHROPIC_API_KEY'))
        reply = anthropic_client.messages.create(
            max_tokens=1024,
            model=model,
            messages=[dict(role='user', content=prompt)])
        return reply.content[0].text

    raise ValueError(f'invalid service {service}')
# run evaluation with a model
def run_eval(data, prompt_fn, parse_fn, delay=0.0, service='groq', **kw):
    """Query the LLM once per example, checkpointing results along the way.

    Args:
      data: iterable of dict-like examples; each one is updated in place.
      prompt_fn: maps an example to the prompt string sent to the model.
      parse_fn: maps the model's full reply to a short answer.
      delay: seconds to sleep after each example.
      service: backend name forwarded to llm().
      **kw: forwarded to the CheckPointer constructor.

    Returns:
      The list of processed examples, each extended with the keys 'ynu'
      (the parsed short answer) and 'long_answer' (the raw reply).
    """
    with CheckPointer(**kw) as checkpointer:
        for example in data:
            reply = llm(prompt_fn(example), service=service)
            example.update(ynu=parse_fn(reply), long_answer=reply)
            checkpointer.append(example)
            # pause between requests; the caller chooses the pacing
            time.sleep(delay)
    return checkpointer.results
# normalize answers
# taken from https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py
def normalize_answer(s):
    """Canonicalize an answer string for comparison.

    Steps, in order: lowercase the text, drop every ASCII punctuation
    character, replace the standalone words 'a', 'an' and 'the' with a
    space, then collapse all runs of whitespace to single spaces and
    strip the ends.
    """
    lowered = s.lower()
    # One pass that deletes every character in string.punctuation.
    unpunctuated = lowered.translate(
        str.maketrans('', '', string.punctuation))
    # Articles are removed after punctuation so that e.g. "the," still
    # matches on word boundaries.
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', unpunctuated)
    return ' '.join(without_articles.split())