-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
104 lines (90 loc) · 4.05 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import gymnasium as gym
from config.config import Config
import argparse
import numpy as np
import torch
import gc
import logging
from logger.logger import Logger
from utils.tools import *
import time
import sys, os
sys.path.append(".")
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from runner.multi_evo_agent_runner import MultiEvoAgentRunner
from runner.multi_agent_runner import MultiAgentRunner
from runner.selfplay_agent_runner import SPAgentRunner
def _parse_args(argv=None):
    """Parse the training command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as the script always did.

    Returns:
        argparse.Namespace with ``cfg_file``, ``use_cuda``, ``gpu_index``,
        ``num_threads``, ``ckpt_dir`` and ``ckpt``.
    """
    parser = argparse.ArgumentParser(description="User's arguments from terminal.")
    parser.add_argument("--cfg",
                        dest="cfg_file",
                        help="Config file",
                        required=True,
                        type=str)
    # str2bool is presumably provided by the `from utils.tools import *` star import.
    parser.add_argument('--use_cuda', type=str2bool, default=True)
    parser.add_argument('--gpu_index', type=int, default=0)
    parser.add_argument('--num_threads', type=int, default=1)
    parser.add_argument('--ckpt_dir', type=str, default=None)
    parser.add_argument('--ckpt', type=str, default='0')
    return parser.parse_args(argv)


def _parse_ckpt(raw):
    """Return ``raw`` as an int when it is a plain decimal number, else unchanged.

    ``str.isdecimal`` is used instead of ``str.isnumeric`` because the latter
    also accepts characters such as ``'²'`` or ``'½'`` that ``int()`` rejects,
    which used to crash the script.
    """
    return int(raw) if raw.isdecimal() else raw


def _select_device(use_cuda, gpu_index):
    """Return the torch device to train on, making it current when it is a GPU.

    CUDA is used only when it is both requested and available;
    ``torch.cuda.is_available()`` is natively False on Apple M1 Macs.
    """
    if use_cuda and torch.cuda.is_available():
        # Only select a GPU when we will actually run on it, so a
        # `--use_cuda False` run does not call into CUDA here.
        torch.cuda.set_device(gpu_index)
        return torch.device('cuda', index=gpu_index)
    return torch.device('cpu')


def _build_runner(cfg, logger, dtype, device, args):
    """Instantiate the runner class named by ``cfg.runner_type``.

    Raises:
        ValueError: if ``cfg.runner_type`` is not one of the supported runners
            (previously this surfaced later as an UnboundLocalError).
    """
    ckpt = _parse_ckpt(args.ckpt)
    if cfg.runner_type == "multi-agent-runner":
        # These runners receive a 2-element checkpoint list
        # (presumably one entry per agent); the same value is used for both.
        return MultiAgentRunner(cfg, logger, dtype, device,
                                num_threads=args.num_threads, training=True,
                                ckpt_dir=args.ckpt_dir, ckpt=[ckpt] * 2)
    if cfg.runner_type == "selfplay-agent-runner":
        return SPAgentRunner(cfg, logger, dtype, device,
                             num_threads=args.num_threads, training=True,
                             ckpt=ckpt)
    if cfg.runner_type == "multi-evo-agent-runner":
        return MultiEvoAgentRunner(cfg, logger, dtype, device,
                                   num_threads=args.num_threads, training=True,
                                   ckpt_dir=args.ckpt_dir, ckpt=[ckpt] * 2)
    raise ValueError(
        "Unsupported runner_type {!r}; expected 'multi-agent-runner', "
        "'selfplay-agent-runner' or 'multi-evo-agent-runner'.".format(cfg.runner_type))


def main(argv=None):
    """Train the runner described by a predefined yaml config file.

    Steps:
    1. Parse the command-line options.
    2. Set up the logger and the run directory, and save the config there.
    3. Fix the default dtype, the device and the numpy/torch seeds.
    4. Build the runner named by ``cfg.runner_type``.
    5. Run ``cfg.max_epoch_num`` epochs of ``optimize`` followed by
       ``save_checkpoint``.

    Args:
        argv: Optional argument list. It defaults to ``sys.argv[1:]`` and lets
            the entry point be driven programmatically.

    Raises:
        ValueError: if ``cfg.runner_type`` names no supported runner.
    """
    # ------------------------------------------------------------------------#
    # Load config options from terminal and predefined yaml file
    # ------------------------------------------------------------------------#
    args = _parse_args(argv)
    cfg = Config(args.cfg_file)

    # ------------------------------------------------------------------------#
    # Define logger and create dirs
    # ------------------------------------------------------------------------#
    logger = Logger(name='current', cfg=cfg)
    logger.propagate = False
    logger.setLevel(logging.INFO)
    logger.set_output_handler()
    logger.print_system_info()
    logger.critical("The current environment is {}.".format(cfg.env_name))
    logger.info("Running directory: {}".format(logger.run_dir))
    logger.info('Type of current running: Training')
    # Only training generates a log file.
    logger.set_file_handler()
    cfg.save_config(logger.run_dir)

    # ------------------------------------------------------------------------#
    # Set torch dtype, device and random seeds
    # ------------------------------------------------------------------------#
    dtype = torch.float64
    torch.set_default_dtype(dtype)
    device = _select_device(args.use_cuda, args.gpu_index)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    # ------------------------------------------------------------------------#
    # Training
    # ------------------------------------------------------------------------#
    runner = _build_runner(cfg, logger, dtype, device, args)
    for epoch in range(cfg.max_epoch_num):
        runner.optimize(epoch)
        runner.save_checkpoint(epoch)
        # Clean up between epochs: collect Python garbage and release
        # cached CUDA allocator blocks (a no-op if CUDA was never initialized).
        gc.collect()
        torch.cuda.empty_cache()
    runner.logger.info('training done!')
if __name__ == "__main__":
    # Entry point: start training only when this file is executed as a script,
    # not when it is imported.
    main()