From 9d24e1f3d64302d864e09977edb8e2f465731c07 Mon Sep 17 00:00:00 2001 From: Andres Kull Date: Wed, 17 Jul 2019 11:21:01 +0300 Subject: [PATCH] hyperparameter defaults updated, added vf_coef and max_grad_norm --- lib/RLTrader.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/lib/RLTrader.py b/lib/RLTrader.py index 4a8ccc0..d37899e 100644 --- a/lib/RLTrader.py +++ b/lib/RLTrader.py @@ -112,6 +112,8 @@ def get_model_params(self): 'cliprange': params['cliprange'], 'noptepochs': int(params['noptepochs']), 'lam': params['lam'], + 'vf_coef': params['vf_coef'], + 'max_grad_norm': params['max_grad_norm'] } def optimize_agent_params(self, trial): @@ -119,13 +121,17 @@ def optimize_agent_params(self, trial): return {'learning_rate': trial.suggest_loguniform('learning_rate', 1e-5, 1.)} return { - 'n_steps': int(trial.suggest_loguniform('n_steps', 16, 2048)), - 'gamma': trial.suggest_loguniform('gamma', 0.9, 0.9999), - 'learning_rate': trial.suggest_loguniform('learning_rate', 1e-5, 1.), - 'ent_coef': trial.suggest_loguniform('ent_coef', 1e-8, 1e-1), - 'cliprange': trial.suggest_uniform('cliprange', 0.1, 0.4), - 'noptepochs': int(trial.suggest_loguniform('noptepochs', 1, 48)), - 'lam': trial.suggest_uniform('lam', 0.8, 1.) + # PPO2 hyperparameter ranges + # https://medium.com/aureliantactics/ppo-hyperparameters-and-ranges-6fc2d29bccbe + 'n_steps': int(trial.suggest_loguniform('n_steps', 32, 5000)), + 'gamma': trial.suggest_uniform('gamma', 0.8, 0.9997), + 'learning_rate': trial.suggest_loguniform('learning_rate', 5e-6, 0.003), + 'ent_coef': trial.suggest_uniform('ent_coef', 0, 1e-2), + 'cliprange': trial.suggest_uniform('cliprange', 0.1, 0.3), + 'noptepochs': int(trial.suggest_uniform('noptepochs', 3, 30)), + 'lam': trial.suggest_uniform('lam', 0.9, 1.), + 'vf_coef': trial.suggest_uniform('vf_coef', 0.5, 1.), + 'max_grad_norm': trial.suggest_uniform('max_grad_norm', 0.4, 0.6), } def optimize_params(self, trial, n_prune_evals_per_trial: int = 2, n_tests_per_eval: int = 1): @@ -289,7 +295,11 @@ def test(self, model_epoch: int = 0, render_env: bool = True, render_report: boo if save_report: reports_path = path.join('data', 'reports', f'{self.study_name}__{model_epoch}.html') + # try: qs.reports.html(returns.Balance, file=reports_path) + # except Exception as err: + # print(err) + # pass self.logger.info( f'Finished testing model ({self.study_name}__{model_epoch}): ${"{:.2f}".format(np.sum(rewards))}')