-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexample.py
100 lines (84 loc) · 3.26 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import sys
import logging
import traceback
from argparse import ArgumentParser, Namespace
from concurrent.futures import ProcessPoolExecutor, as_completed
import jsonlines
from tqdm import tqdm
from search_llm_actions import (
run_shell_script, run_check_vllm,
Search, Node,
VLLMCaller, TogetherCaller, try_till_success
)
# Make sibling modules (e.g. search_llm_actions) importable when this script
# is run directly: append this file's directory to the module search path.
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(current_dir)
def get_logger(name: str, log_file: str = "main.log") -> logging.Logger:
    """Return a DEBUG-level logger that appends records to *log_file*.

    Args:
        name: Logger name, typically ``__name__``.
        log_file: Path of the file the handler writes to.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    # BUG FIX: logging.getLogger returns the SAME instance per name, so the
    # original unconditionally attached a new FileHandler on every call,
    # duplicating each record in the log file. Attach the handler only once.
    if not logger.handlers:
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)
    return logger
def get_args() -> Namespace:
    """Parse and return the command-line options for a search run.

    Returns:
        Parsed arguments: output path, search budget (iterations, width,
        depth) and parallelism settings (workers, mini-batch size).
    """
    parser = ArgumentParser()
    # Table-driven registration keeps flag/type/default triples aligned.
    for flag, value_type, default in (
        ("--output_file", str, "search_res.jsonl"),
        ("--max_iter", int, 8),
        ("--expansion_width", int, 2),
        ("--max_depth", int, 5),
        ("--worker_num", int, 4),
        ("--mini_batch_size", int, 8),
    ):
        parser.add_argument(flag, type=value_type, default=default)
    return parser.parse_args()
def main():
    """Run the batched, multi-process search and append results to JSONL.

    Builds (placeholder) input infos, fans each mini-batch out to a
    ProcessPoolExecutor, writes every successful search result to
    ``args.output_file`` and tracks a running success count in the
    progress bar.
    """
    logger = get_logger(__name__)
    logger.info("### Start search ###")
    args = get_args()
    # Set llm_caller
    args.llm_caller = TogetherCaller()
    # # deploy model if it's not deployed
    # run_check_vllm()
    # Get infos from data (placeholder inputs; replace with real data loading)
    infos = [
        {
            "tags": {"tag": f"tag{_}"},
            "something": f"something{_}",
        }
        for _ in range(16)
    ]
    td = tqdm(total=len(infos), desc="Total", position=0)
    suc_count = 0
    for batch_start in range(0, len(infos), args.mini_batch_size):
        batch_idx = batch_start // args.mini_batch_size
        with ProcessPoolExecutor(max_workers=args.worker_num) as executor:
            futures = []
            for idx in range(
                batch_start, min(batch_start + args.mini_batch_size, len(infos))
            ):
                # args is re-used per task; from_args is called synchronously
                # here, so mutating args.tags before each submit is safe.
                args.tags = infos[idx]["tags"]
                search = Search.from_args(args, infos[idx])
                future = executor.submit(search.search, td_position=(idx - batch_start + 1))
                futures.append(future)
            for future in as_completed(futures):
                exc = future.exception()
                if exc is not None:
                    logger.error(exc)
                    # BUG FIX: traceback.format_exc() formats the *current*
                    # exception of this (parent) process — here there is none,
                    # so it logged "NoneType: None". Format the worker's
                    # exception object instead.
                    logger.error(
                        "".join(
                            traceback.format_exception(type(exc), exc, exc.__traceback__)
                        )
                    )
                else:
                    _reward_ave, dic_res = future.result()  # average reward unused here
                    with jsonlines.open(args.output_file, mode="a") as writer:
                        writer.write(dic_res)
                    td.update(1)
                    # A run counts as a success if any terminal node (end_flag)
                    # achieved the full reward of 1.0.
                    suc_count += any(
                        node["node_info"]["reward_sum"] == 1.0
                        for node in dic_res["nodes"].values()
                        if node["node_info"]["end_flag"]
                    )
                    td.set_postfix({
                        "batch": batch_idx,
                        "suc_count": suc_count,
                    })
    logger.info("### Finish search ###")
# Script entry point: run the search only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()