-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathrun.py
executable file
·70 lines (55 loc) · 1.83 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python
import os
import _jsonnet
import json
import argparse
import attr
import wandb
from experiments.spider_dg import (
train,
meta_train,
)
@attr.attrs
class TrainConfig:
    """Arguments handed to ``train.main`` for a standard training run.

    attrs generates ``__init__(config, config_args, logdir)`` (positional, in
    field order) along with ``__repr__`` and ``__eq__``.
    """

    # Model config file path, taken from the experiment config's "model_config".
    config = attr.attrib()
    # JSON string of model-config overrides, or None when none were given.
    config_args = attr.attrib()
    # Log directory for this run (already joined with LOG_BASE_DIR if it is set).
    logdir = attr.attrib()
@attr.attrs
class MetaTrainConfig:
    """Arguments handed to ``meta_train.main`` for a meta-training run.

    attrs generates ``__init__(config, config_args, logdir)`` (positional, in
    field order) along with ``__repr__`` and ``__eq__``.
    """

    # Model config file path, taken from the experiment config's "model_config".
    config = attr.attrib()
    # JSON string of model-config overrides, or None when none were given.
    config_args = attr.attrib()
    # Log directory for this run (already joined with LOG_BASE_DIR if it is set).
    logdir = attr.attrib()
def main():
    """Command-line entry point: load an experiment config and launch training.

    Usage: ``run.py {train,meta_train} EXP_CONFIG_FILE``

    The jsonnet experiment config must define ``model_config``, ``logdir`` and
    ``project``; it may also define ``model_config_args``, which is forwarded
    as a JSON string. When the ``LOG_BASE_DIR`` environment variable is set,
    ``logdir`` is joined onto it. A wandb run is started with
    ``project=project``, ``group=<last component of logdir>`` and
    ``job_type=<mode>``. Control is then dispatched to ``train.main`` or
    ``meta_train.main`` with the matching config object.
    """
    # mode -> (config class, module whose ``main`` runs that mode).
    # Module objects are stored (not ``module.main``) so the attribute lookup
    # only happens for the mode that is actually selected.
    runners = {
        "train": (TrainConfig, train),
        "meta_train": (MetaTrainConfig, meta_train),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "mode", choices=list(runners), help="train/meta_train",
    )
    parser.add_argument("exp_config_file", help="jsonnet file for experiments")
    args = parser.parse_args()

    exp_config = json.loads(_jsonnet.evaluate_file(args.exp_config_file))
    model_config_file = exp_config["model_config"]
    # Overrides are forwarded as a JSON string; None means "no overrides".
    if "model_config_args" in exp_config:
        model_config_args = json.dumps(exp_config["model_config_args"])
    else:
        model_config_args = None

    # cluster base dir
    log_base_dir = os.environ.get("LOG_BASE_DIR")
    if log_base_dir is None:
        print(f"Using default log base dir {os.getcwd()}")
        logdir = exp_config["logdir"]
    else:
        logdir = os.path.join(log_base_dir, exp_config["logdir"])

    # wandb init: the run group is the last path component of logdir.
    # Strip trailing slashes first so "runs/exp1/" yields "exp1", not "".
    expname = exp_config["logdir"].rstrip("/").split("/")[-1]
    project = exp_config["project"]

    config_cls, runner = runners[args.mode]
    # Both supported modes share one global wandb session started here. A
    # distributed mode (not offered) would need one session per process instead.
    wandb.init(project=project, group=expname, job_type=args.mode)
    runner.main(config_cls(model_config_file, model_config_args, logdir))
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()