Skip to content

Commit

Permalink
Update CLI params to better match RLTrader API
Browse files Browse the repository at this point in the history
  • Loading branch information
notadamking committed Jul 10, 2019
1 parent adf0ef6 commit a932a9d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 28 deletions.
14 changes: 11 additions & 3 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def run_optimize(args, logger):
from lib.RLTrader import RLTrader

trader = RLTrader(**vars(args), logger=logger)
trader.optimize(args.trials)
trader.optimize(n_trials=args.trials, n_prune_evals_per_trial=args.prune_evals, n_tests_per_eval=args.eval_tests)


if __name__ == '__main__':
Expand All @@ -39,8 +39,16 @@ def run_optimize(args, logger):
trader = RLTrader(**vars(args), logger=logger)

if args.command == 'train':
trader.train(n_epochs=args.epochs)
trader.train(n_epochs=args.epochs,
save_every=args.save_every,
test_trained_model=args.test_trained,
render_test_env=args.render_test,
render_report=args.render_report,
save_report=args.save_report)
elif args.command == 'test':
trader.test(model_epoch=args.model_epoch, should_render=args.no_render, render_tearsheet=args.no_tearsheet)
trader.test(model_epoch=args.model_epoch,
render_env=args.render_env,
render_report=args.render_report,
save_report=args.save_report)
elif args.command == 'update-static-data':
download_data_async()
62 changes: 37 additions & 25 deletions lib/cli/RLTraderCLI.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class RLTraderCLI:
def __init__(self):
config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument("-f", "--from-config", help="Specify config file", metavar="FILE")

args, _ = config_parser.parse_known_args()
defaults = {}

Expand All @@ -17,45 +18,56 @@ def __init__(self):
defaults = dict(config.items("Defaults"))

formatter = argparse.ArgumentDefaultsHelpFormatter
self.parser = argparse.ArgumentParser(
formatter_class=formatter,
parents=[config_parser],
description=__doc__
)
self.parser = argparse.ArgumentParser(formatter_class=formatter,
parents=[config_parser],
description=__doc__)

self.parser.add_argument("--data-provider", "-o", type=str, default="static")
self.parser.add_argument("--input-data-path", "-t", type=str, default="data/input/coinbase-1h-btc-usd.csv")
self.parser.add_argument("--data-provider", "-d", type=str, default="static")
self.parser.add_argument("--input-data-path", "-n", type=str, default="data/input/coinbase-1h-btc-usd.csv")
self.parser.add_argument("--pair", "-p", type=str, default="BTC/USD")
self.parser.add_argument("--debug", "-n", action='store_false')
self.parser.add_argument("--debug", "-D", action='store_false')
self.parser.add_argument('--mini-batches', type=int, default=1, help='Mini batches', dest='n_minibatches')
self.parser.add_argument('--train-split-percentage', type=float, default=0.8, help='Train set percentage')
self.parser.add_argument('--verbose-model', type=int, default=1, help='Verbose model')
self.parser.add_argument('--params-db-path', type=str, default='sqlite:///data/params.db',
help='Params path')
self.parser.add_argument(
'--tensor-board-path',
type=str,
default=os.path.join('data', 'tensorboard'),
help='Tensorboard path',
dest='tensorboard_path'
)
self.parser.add_argument('--parallel-jobs', type=int, default=multiprocessing.cpu_count(),
self.parser.add_argument('--verbose-model', type=int, default=1, help='Verbose model', dest='model_verbose')
self.parser.add_argument('--params-db-path', type=str, default='sqlite:///data/params.db', help='Params path')
self.parser.add_argument('--tensorboard-path',
type=str,
default=os.path.join('data', 'tensorboard'),
help='Tensorboard path')
self.parser.add_argument('--parallel-jobs',
type=int,
default=multiprocessing.cpu_count(),
help='How many processes in parallel')

subparsers = self.parser.add_subparsers(help='Command', dest="command")

optimize_parser = subparsers.add_parser('optimize', description='Optimize model parameters')
optimize_parser.add_argument('--trials', type=int, default=1, help='Number of trials')

optimize_parser.add_argument('--verbose-model', type=int, default=1, help='Verbose model', dest='model_verbose')
optimize_parser.add_argument('--prune-evals',
type=int,
default=2,
help='Number of pruning evaluations per trial')
optimize_parser.add_argument('--eval-tests', type=int, default=1, help='Number of tests per pruning evaluation')

train_parser = subparsers.add_parser('train', description='Train model')
train_parser.add_argument('--epochs', type=int, default=1, help='Number of epochs to train')
train_parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train')
train_parser.add_argument('--save-every', type=int, default=1, help='Save the trained model every n epochs')
train_parser.add_argument('--no-test', dest="test_trained", action="store_false", help='Test each saved model')
train_parser.add_argument('--render-test', dest="render_test",
action="store_true", help='Render the test environment')
train_parser.add_argument('--no-report', dest="render_report", action="store_false",
help='Render the performance report')
train_parser.add_argument('--save-report', dest="save_report", action="store_true",
help='Save the performance report as .html')

test_parser = subparsers.add_parser('test', description='Test model')
test_parser.add_argument('--model-epoch', type=int, default=1, help='Model epoch index')
test_parser.add_argument('--no-render', action='store_false', help='Do not render test')
test_parser.add_argument('--no-tearsheet', action='store_false', help='Do not render tearsheet')
test_parser.add_argument('--model-epoch', type=int, default=0, help='Model epoch index')
test_parser.add_argument('--no-render', dest="render_env", action="store_false",
help='Render the test environment')
test_parser.add_argument('--no-report', dest="render_report", action="store_false",
help='Render the performance report')
test_parser.add_argument('--save-report', dest="save_report", action="store_true",
help='Save the performance report as .html')

subparsers.add_parser('update-static-data', description='Update static data')

Expand Down

0 comments on commit a932a9d

Please sign in to comment.