From a2b86107cfcab3b8468dfbfc5faeba65330b0aa4 Mon Sep 17 00:00:00 2001 From: adamjking3 Date: Sun, 7 Jul 2019 23:49:57 -0700 Subject: [PATCH] Fix TF hanging issue in multiprocessing pools. --- cli.py | 32 +++++++++++++++----------------- optimize.py | 19 +++++++------------ 2 files changed, 22 insertions(+), 29 deletions(-) diff --git a/cli.py b/cli.py index 4d11ab2..920f664 100644 --- a/cli.py +++ b/cli.py @@ -1,5 +1,6 @@ import numpy as np -import multiprocessing + +from multiprocessing.pool import ThreadPool from lib.RLTrader import RLTrader from lib.cli.RLTraderCLI import RLTraderCLI @@ -11,32 +12,29 @@ args = trader_cli.get_args() -def run_concurrent_optimize(): - trader = RLTrader(**vars(args)) - trader.optimize(args.trials) - +def run_optimize(params): + trader_args, logger = params -def concurrent_optimize(): - processes = [] - for i in range(args.parallel_jobs): - processes.append(multiprocessing.Process(target=run_concurrent_optimize, args=())) + trader = RLTrader(**vars(trader_args), logger=logger) + trader.optimize(trader_args.trials) - print(processes) - for p in processes: - p.start() +def optimize_concurrent(trader_args, logger): + n_processes = trader_args.parallel_jobs - for p in processes: - p.join() + opt_pool = ThreadPool(processes=n_processes) + opt_pool.map(run_optimize, [((trader_args, logger)) for _ in range(n_processes)]) if __name__ == '__main__': logger = init_logger(__name__, show_debug=args.debug) - trader = RLTrader(**vars(args), logger=logger) if args.command == 'optimize': - concurrent_optimize() - elif args.command == 'train': + optimize_concurrent(args, logger) + + trader = RLTrader(**vars(args), logger=logger) + + if args.command == 'train': trader.train(n_epochs=args.epochs) elif args.command == 'test': trader.test(model_epoch=args.model_epoch, should_render=args.no_render) diff --git a/optimize.py b/optimize.py index fb5634b..b6d4154 100644 --- a/optimize.py +++ b/optimize.py @@ -1,6 +1,8 @@ -import multiprocessing +import os import numpy as np +from multiprocessing.pool import ThreadPool + from lib.RLTrader import RLTrader np.warnings.filterwarnings('ignore') @@ -12,18 +14,11 @@ def optimize_code(params): if __name__ == '__main__': - n_process = multiprocessing.cpu_count() - params = {'n_envs': n_process} - - processes = [] - for i in range(n_process): - processes.append(multiprocessing.Process(target=optimize_code, args=(params,))) - - for p in processes: - p.start() + n_processes = 6 # os.cpu_count() + params = {'n_envs': n_processes} - for p in processes: - p.join() + opt_pool = ThreadPool(processes=n_processes) + opt_pool.map(optimize_code, [params for _ in range(n_processes)]) trader = RLTrader(**params) trader.train(test_trained_model=True, render_trained_model=True)