-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path run_script.py
63 lines (56 loc) · 2.16 KB
/
run_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import argparse
import datetime
from functools import partial
from multiprocessing.dummy import Pool
from subprocess import call
from utilities import get_time_string, get_log_dir, parse_time_string
# Batch runner: launches RUN_FILE repeatedly for every (env, agent) pair,
# several at a time, and reports which runs exited with a non-zero status.

DEFAULT_NUM_RUNS = 15
RUN_FILE = "framework.py"
LOG_DIR_ROOT = "logfiles"
DEFAULT_AGENT = "Qlearner"
DEFAULT_ENV = "CartPole-v0"
DEFAULT_NUM_ROLLOUTS = 600
DEFAULT_NUM_WORKERS = 7
DEFAULT_LEARNING_RATE = 1e-3


def parse_args(argv=None):
    """Parse the batch runner's command-line options.

    :param argv: arguments to parse; ``None`` means ``sys.argv[1:]``.
    :return: the populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_runs', type=int, default=DEFAULT_NUM_RUNS)
    parser.add_argument('--agentname', type=str, nargs="*", default=[DEFAULT_AGENT])
    parser.add_argument('--envname', type=str, nargs="*", default=[DEFAULT_ENV])
    # NOTE(review): --render and --max_timesteps are accepted but are NOT
    # forwarded to RUN_FILE by build_commands(); confirm that framework.py
    # supports them before wiring them through.
    parser.add_argument('--render', action='store_true')
    parser.add_argument("--max_timesteps", type=int)
    parser.add_argument('--num_rollouts', type=int, default=DEFAULT_NUM_ROLLOUTS)
    parser.add_argument('--learning_rate', type=float, default=DEFAULT_LEARNING_RATE)
    parser.add_argument('--n_hiddens', nargs="+", type=int, default=[8])
    parser.add_argument('--log_dir', type=str, default=LOG_DIR_ROOT)
    parser.add_argument('--log_tf', action='store_true')
    parser.add_argument('--num_workers', type=int, default=DEFAULT_NUM_WORKERS)
    return parser.parse_args(argv)


def build_commands(args):
    """Build one argv list per run of RUN_FILE.

    For every env in ``args.envname`` (outer loop) and every agent in
    ``args.agentname`` (inner loop), ``args.num_runs`` identical commands are
    produced, in that order. Each command is an argument list, so values
    containing spaces or shell metacharacters are passed through verbatim.

    :param args: namespace as returned by :func:`parse_args`.
    :return: list of argv lists of the form
        ``python RUN_FILE <agent> <env> --log_dir_root=... --num_rollouts=...
        [--no_tf_log] --learning_rate=... --n_hiddens <h1> [<h2> ...]``.
    """
    shared = [
        "--log_dir_root={}".format(args.log_dir),
        "--num_rollouts={}".format(args.num_rollouts),
    ]
    # TensorFlow logging is opt-in: disable it unless --log_tf was given.
    if not args.log_tf:
        shared.append("--no_tf_log")
    shared.append("--learning_rate={}".format(args.learning_rate))
    shared.append("--n_hiddens")
    shared.extend(str(width) for width in args.n_hiddens)

    commands = []
    for env in args.envname:
        for agent in args.agentname:
            command = ["python", RUN_FILE, agent, env] + shared
            # Give each run its own list so no two entries alias each other.
            commands.extend(list(command) for _ in range(args.num_runs))
    return commands


def run_commands(commands, num_workers):
    """Run every command with bounded concurrency and report failures.

    ``multiprocessing.dummy.Pool`` is a thread pool. Each worker only blocks
    waiting on its child process, so threads are sufficient. Results are
    consumed in submission order via ``imap``, and each failure is printed
    as soon as its result is reached.

    :param commands: list of argv lists, each executed without a shell.
    :param num_workers: maximum number of commands running at once.
    :return: indices into ``commands`` of the commands whose return code
        was non-zero.
    """
    failed = []
    with Pool(num_workers) as pool:
        for index, returncode in enumerate(pool.imap(call, commands)):
            if returncode != 0:
                print("command {} failed with return code {}: {}".format(
                    index, returncode, " ".join(commands[index])))
                failed.append(index)
    return failed


def main(argv=None):
    """Create the log directory, launch all runs, and summarize the outcome.

    Prints "Success" only when every command exited with status 0. Otherwise
    prints how many failed and exits with status 1.

    :param argv: arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)
    # makedirs(exist_ok=True) tolerates an existing directory and creates
    # parents, while still surfacing real errors such as permission problems.
    os.makedirs(args.log_dir, exist_ok=True)
    commands = build_commands(args)
    failed = run_commands(commands, args.num_workers)
    if failed:
        print("{} of {} commands failed".format(len(failed), len(commands)))
        raise SystemExit(1)
    print("Success")


if __name__ == "__main__":
    main()