Skip to content

Commit

Permalink
Update syntax styling and improve parameter naming
Browse files Browse the repository at this point in the history
  • Loading branch information
notadamking committed Jul 10, 2019
1 parent 70cc6c3 commit 80fa433
Showing 1 changed file with 37 additions and 16 deletions.
53 changes: 37 additions & 16 deletions lib/RLTrader.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,12 @@ def optimize_params(self, trial, n_prune_evals_per_trial: int = 2, n_tests_per_e
validation_env = SubprocVecEnv([make_env(validation_provider, i) for i in range(1)])

model_params = self.optimize_agent_params(trial)
model = self.Model(self.Policy, train_env, verbose=self.model_verbose, nminibatches=1,
tensorboard_log=self.tensorboard_path, **model_params)
model = self.Model(self.Policy,
train_env,
verbose=self.model_verbose,
nminibatches=1,
tensorboard_log=self.tensorboard_path,
**model_params)

last_reward = -np.finfo(np.float16).max
n_steps_per_eval = int(len(train_provider.data_frame) / n_prune_evals_per_trial)
Expand Down Expand Up @@ -181,9 +185,9 @@ def optimize_params(self, trial, n_prune_evals_per_trial: int = 2, n_tests_per_e

return -1 * last_reward

def optimize(self, n_trials: int = 20, *optimize_params):
def optimize(self, n_trials: int = 20, **optimize_params):
try:
self.optuna_study.optimize(self.optimize_params, n_trials=n_trials, n_jobs=1, *optimize_params)
self.optuna_study.optimize(self.optimize_params, n_trials=n_trials, n_jobs=1, **optimize_params)
except KeyboardInterrupt:
pass

Expand All @@ -197,7 +201,13 @@ def optimize(self, n_trials: int = 20, *optimize_params):

return self.optuna_study.trials_dataframe()

def train(self, n_epochs: int = 10, save_every: int = 1, test_trained_model: bool = True, render_trained_model: bool = False, save_results: bool = True):
def train(self,
n_epochs: int = 10,
save_every: int = 1,
test_trained_model: bool = True,
render_test_env: bool = False,
render_report: bool = True,
save_report: bool = False):
train_provider, test_provider = self.data_provider.split_data_train_test(self.train_split_percentage)

del test_provider
Expand All @@ -206,8 +216,12 @@ def train(self, n_epochs: int = 10, save_every: int = 1, test_trained_model: boo

model_params = self.get_model_params()

model = self.Model(self.Policy, train_env, verbose=self.model_verbose, nminibatches=self.n_minibatches,
tensorboard_log=self.tensorboard_path, **model_params)
model = self.Model(self.Policy,
train_env,
verbose=self.model_verbose,
nminibatches=self.n_minibatches,
tensorboard_log=self.tensorboard_path,
**model_params)

self.logger.info(f'Training for {n_epochs} epochs')

Expand All @@ -223,11 +237,14 @@ def train(self, n_epochs: int = 10, save_every: int = 1, test_trained_model: boo
model.save(model_path)

if test_trained_model:
self.test(model_epoch, should_render=render_trained_model, render_tearsheet=False, save_tearsheet=save_results)
self.test(model_epoch,
render_env=render_test_env,
render_report=render_report,
save_report=save_report)

self.logger.info(f'Trained {n_epochs} models')

def test(self, model_epoch: int = 0, should_render: bool = True, render_tearsheet: bool = True, save_tearsheet: bool = False):
def test(self, model_epoch: int = 0, render_env: bool = True, render_report: bool = True, save_report: bool = False):
train_provider, test_provider = self.data_provider.split_data_train_test(self.train_split_percentage)

del train_provider
Expand Down Expand Up @@ -255,20 +272,24 @@ def test(self, model_epoch: int = 0, should_render: bool = True, render_tearshee

rewards.append(reward)

if should_render:
if render_env:
test_env.render(mode='human')

if done:
net_worths = pd.DataFrame(
{'Date': info[0]['timestamps'],
'Balance': info[0]['networths'],
})
net_worths = pd.DataFrame({
'Date': info[0]['timestamps'],
'Balance': info[0]['networths'],
})

net_worths.set_index('Date', drop=True, inplace=True)
returns = net_worths.pct_change()[1:]
if(render_tearsheet):

if render_report:
qs.plots.snapshot(returns.Balance, title='RL Trader Performance')
if(save_tearsheet):

if save_report:
reports_path = path.join('data', 'reports', f'{self.study_name}__{model_epoch}.html')
qs.reports.html(returns.Balance, file=reports_path)

self.logger.info(
f'Finished testing model ({self.study_name}__{model_epoch}): ${"{:.2f}".format(np.sum(rewards))}')

0 comments on commit 80fa433

Please sign in to comment.