-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_aime.py
More file actions
132 lines (114 loc) · 3.67 KB
/
run_aime.py
File metadata and controls
132 lines (114 loc) · 3.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import asyncio
import json
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from datasets import load_dataset
from openai.types.chat import ChatCompletion
from tqdm import tqdm
from deepthink.agent import Agent
from deepthink.llm import LLM
from deepthink.registry import get_agent_cls, list_agents
def parse_args():
    """Parse command-line arguments for the AIME evaluation runner.

    Returns:
        tuple[argparse.Namespace, dict]: the parsed known arguments, and a
        dict of extra ``--key value`` pairs (parsed by ``parse_kwargs``)
        that are forwarded verbatim to the agent's ``run`` method.
    """
    parser = ArgumentParser(
        prog=f"uv run {os.path.basename(__file__)}",
        formatter_class=ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("model", type=str, help="base model to use")
    parser.add_argument(
        "strategy",
        type=str,
        help="inference strategy to use",
        choices=list_agents(),
    )
    parser.add_argument(
        "--base_url",
        type=str,
        help="openai-api-compatible api endpoint",
        default="https://openrouter.ai/api/v1",
    )
    parser.add_argument(
        "--api_key",
        type=str,
        help="api key",
        # Use .get() rather than indexing: os.environ["OPENROUTER_API_KEY"]
        # would raise KeyError while *building* the parser whenever the env
        # var is unset — even if the user supplies --api_key explicitly.
        default=os.environ.get("OPENROUTER_API_KEY"),
    )
    parser.add_argument(
        "--n",
        type=int,
        help="number of samples per question",
        default=1,
    )
    parser.add_argument(
        "--extra_body",
        type=json.loads,
        help="extra body for openai request",
        default={},
    )
    parser.add_argument("--output", type=str, help="output file", required=False)
    # Unknown --key/value tokens are not errors: they are collected and
    # handed to parse_kwargs so strategies can accept ad-hoc options.
    args, extra_args = parser.parse_known_args()
    return args, parse_kwargs(extra_args)
def auto_cast(val):
    """Best-effort conversion of a CLI token to a Python value.

    ``"true"``/``"false"`` (case-insensitive) become booleans; otherwise
    an int is attempted, then a float, and finally the original string is
    returned unchanged.
    """
    lowered = val.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for caster in (int, float):
        try:
            return caster(val)
        except ValueError:
            continue
    return val
def parse_kwargs(extra_args):
    """Convert a flat ``["--key", "value", "--flag", ...]`` list to a dict.

    A ``--key`` followed by a non-flag token consumes that token as its
    value (auto-cast to bool/int/float/str); a ``--key`` followed by
    another flag, or at the end of the list, becomes ``True``. Dashes in
    key names are normalized to underscores; stray non-flag tokens are
    silently skipped.
    """
    parsed = {}
    idx = 0
    total = len(extra_args)
    while idx < total:
        token = extra_args[idx]
        idx += 1
        if not token.startswith("--"):
            continue  # stray positional token: ignore it
        name = token[2:].replace("-", "_")
        if idx < total and not extra_args[idx].startswith("--"):
            parsed[name] = auto_cast(extra_args[idx])
            idx += 1
        else:
            parsed[name] = True  # bare flag with no value
    return parsed
async def main():
    """Run the configured agent on AIME 2025 and append results as JSONL.

    Loads the dataset, fans out ``args.n`` concurrent generations per
    problem via asyncio.gather, then writes one JSON line per sample
    containing the final response, intermediate completions, the problem,
    and the reference answer.
    """
    args, extra_args = parse_args()
    # NOTE(review): downloads/loads the HF dataset; assumes network or a
    # local cache is available.
    dataset = load_dataset("yentinglin/aime_2025", split="train")
    llm = LLM(
        model=args.model,
        base_url=args.base_url,
        api_key=args.api_key,
        extra_body=args.extra_body,
    )
    agent_cls = get_agent_cls(args.strategy)
    # One progress tick per sample: n samples for each dataset row.
    bar = tqdm(desc="Generation progress", total=len(dataset) * args.n)
    async def generate(row):
        # Fresh agent per sample so per-run state (e.g. call history)
        # is not shared across concurrent tasks.
        agent: Agent = agent_cls(llm=llm)
        result = await agent.run(
            prompt=f"Please put final answer in a \\boxed{{}} environment.\n{row['problem']}",
            **extra_args,
        )
        bar.update(1)
        return agent.calls, result
    # All samples are launched at once with no concurrency cap;
    # rate-limiting is presumably handled by the endpoint — TODO confirm.
    tasks = [generate(row) for row in dataset for _ in range(args.n)]
    # gather preserves task order, so responses line up with the
    # row-repeated list rebuilt in the zip below.
    responses: list[tuple[list[ChatCompletion], ChatCompletion]] = await asyncio.gather(
        *tasks
    )
    # "a+" appends, so re-running accumulates into the same file rather
    # than overwriting earlier results.
    with open(
        args.output
        if args.output
        else f"{args.model.split('/')[-1]}-{args.strategy}.jsonl",
        "a+",
    ) as f:
        for (completions, final_response), row in zip(
            responses, [row for row in dataset for _ in range(args.n)], strict=True
        ):
            eval_entry = {}
            eval_entry["final_response"] = final_response.model_dump()
            # NOTE(review): agent.calls is dumped as-is; assumes its items
            # are already JSON-serializable — confirm against Agent.
            eval_entry["completions"] = completions
            eval_entry["problem"] = row["problem"]
            eval_entry["answer"] = row["answer"]
            eval_entry["dataset"] = "aime2025"
            f.write(json.dumps(eval_entry) + "\n")
# Script entry point: drive the async main() on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())