-
-
Notifications
You must be signed in to change notification settings - Fork 541
/
Copy pathcli.py
59 lines (42 loc) · 1.88 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import warnings
from multiprocessing import Process

import numpy as np

from lib.cli.RLTraderCLI import RLTraderCLI
from lib.util.logger import init_logger
from lib.cli.functions import download_data_async
from lib.env.reward import BaseRewardStrategy, IncrementalProfit, WeightedUnrealizedProfit
# Silence all Python warnings for this CLI run. This used to go through
# `np.warnings`, which was merely an undocumented alias of the stdlib
# `warnings` module and is gone in newer NumPy releases (AttributeError at
# import time); calling the stdlib module directly is equivalent everywhere.
warnings.filterwarnings('ignore')

# Command-line arguments are parsed once, at import time, so both the
# `__main__` dispatch and `run_optimize` can read the module-level `args` /
# `reward_strategy`.
# NOTE(review): importing this module therefore parses sys.argv as a side
# effect (including in multiprocessing children that re-import it) — confirm
# this is intended.
trader_cli = RLTraderCLI()
args = trader_cli.get_args()

# Map the CLI `reward_strat` value to the reward-strategy class handed to
# RLTrader. An unrecognised value raises KeyError here.
# NOTE(review): presumably the CLI already restricts the accepted choices to
# these keys — verify in RLTraderCLI.
rewards = {
    "incremental-profit": IncrementalProfit,
    "weighted-unrealized-profit": WeightedUnrealizedProfit,
}
reward_strategy = rewards[args.reward_strat]
def run_optimize(args, logger):
    """Worker entry point for the ``optimize`` command.

    Builds an ``RLTrader`` from the parsed command-line namespace and runs
    ``args.trials`` optimization trials in the calling process.

    Args:
        args: Parsed CLI namespace; every attribute is forwarded to
            ``RLTrader`` as a keyword argument.
        logger: Logger passed through to the trader.

    The reward strategy comes from the module-level ``reward_strategy``
    selected at import time.
    """
    # Imported lazily inside the worker rather than at module top.
    # NOTE(review): presumably to keep the trader's heavy imports out of the
    # parent process before the workers are started — confirm.
    from lib.RLTrader import RLTrader

    optimizer = RLTrader(
        **vars(args),
        logger=logger,
        reward_strategy=reward_strategy,
    )
    optimizer.optimize(n_trials=args.trials)
if __name__ == '__main__':
    # Dispatch on the sub-command parsed into the module-level `args`.
    logger = init_logger(__name__, show_debug=args.debug)

    if args.command == 'optimize':
        # Start `args.parallel_jobs` independent worker processes, each running
        # `run_optimize` with the same arguments, then wait for all of them.
        processes = [Process(target=run_optimize, args=(args, logger))
                     for _ in range(args.parallel_jobs)]

        for proc in processes:
            proc.start()

        for proc in processes:
            proc.join()
        # NOTE(review): worker exit codes (`proc.exitcode`) are not checked,
        # so a crashed worker goes unreported here.
    elif args.command in ('train', 'test'):
        # Only train/test need a trader in this process. Previously one was
        # built unconditionally, i.e. also after the optimize workers had
        # finished and for 'update-static-data', where it was never used.
        from lib.RLTrader import RLTrader

        trader = RLTrader(**vars(args), logger=logger, reward_strategy=reward_strategy)

        if args.command == 'train':
            trader.train(n_epochs=args.epochs,
                         save_every=args.save_every,
                         test_trained_model=args.test_trained,
                         render_test_env=args.render_test,
                         render_report=args.render_report,
                         save_report=args.save_report)
        else:
            trader.test(model_epoch=args.model_epoch,
                        render_env=args.render_env,
                        render_report=args.render_report,
                        save_report=args.save_report)
    elif args.command == 'update-static-data':
        download_data_async()