From 80a55cf7b536355892906d89fa6c96fa8787d243 Mon Sep 17 00:00:00 2001
From: adamjking3
Date: Fri, 5 Jul 2019 17:11:20 -0700
Subject: [PATCH] Implement BaseRewardStrategy interface more clearly.

---
 lib/env/TradingEnv.py                | 29 +++++++++-------------------
 lib/env/__init__.py                  |  3 +++
 lib/env/reward/BaseRewardStrategy.py | 16 +++++++++++++++
 lib/env/reward/IncrementalProfit.py  | 25 ++++++++++++++++++++++++
 lib/env/reward/__init__.py           |  2 ++
 optimize.py                          | 22 +++++++++++----------
 6 files changed, 67 insertions(+), 30 deletions(-)
 create mode 100644 lib/env/reward/BaseRewardStrategy.py
 create mode 100644 lib/env/reward/IncrementalProfit.py
 create mode 100644 lib/env/reward/__init__.py

diff --git a/lib/env/TradingEnv.py b/lib/env/TradingEnv.py
index 06c92c4..5fc0166 100644
--- a/lib/env/TradingEnv.py
+++ b/lib/env/TradingEnv.py
@@ -5,6 +5,7 @@
 from gym import spaces
 
 from lib.env.render import TradingChart
+from lib.env.reward import BaseRewardStrategy, IncrementalProfit
 from lib.data.providers import BaseDataProvider
 from lib.data.features.transform import max_min_normalize, log_and_difference
 from lib.util.logger import init_logger
@@ -15,16 +16,16 @@ class TradingEnv(gym.Env):
     metadata = {'render.modes': ['human', 'system', 'none']}
     viewer = None
 
-    def __init__(self, data_provider: BaseDataProvider, initial_balance=10000, commission=0.0025, **kwargs):
+    def __init__(self, data_provider: BaseDataProvider, reward_strategy: BaseRewardStrategy = IncrementalProfit, initial_balance=10000, commission=0.0025, **kwargs):
         super(TradingEnv, self).__init__()
 
         self.logger = kwargs.get('logger', init_logger(__name__, show_debug=kwargs.get('show_debug', True)))
 
         self.data_provider = data_provider
+        self.reward_strategy = reward_strategy
         self.initial_balance = initial_balance
         self.commission = commission
 
-        self.reward_fn = kwargs.get('reward_fn', self._reward_incremental_profit)
         self.benchmarks = kwargs.get('benchmarks', [])
         self.enable_stationarization = kwargs.get('enable_stationarization', True)
 
@@ -101,25 +102,13 @@ def _take_action(self, action):
             'revenue_from_sold': revenue_from_sold,
         }, ignore_index=True)
 
-    def _reward_incremental_profit(self, observations, net_worths, account_history, last_bought, last_sold, current_price):
-        prev_balance = account_history['balance'].values[-2]
-        curr_balance = account_history['balance'].values[-1]
-        reward = 0
-
-        if curr_balance > prev_balance:
-            reward = net_worths[-1] - net_worths[last_bought]
-        elif curr_balance < prev_balance:
-            reward = observations['Close'].values[last_sold] - current_price
-
-        return reward
-
     def _reward(self):
-        reward = self.reward_fn(observations=self.observations,
-                                net_worths=self.net_worths,
-                                account_history=self.account_history,
-                                last_bought=self.last_bought,
-                                last_sold=self.last_sold,
-                                current_price=self._current_price())
+        reward = self.reward_strategy.get_reward(observations=self.observations,
+                                                 net_worths=self.net_worths,
+                                                 account_history=self.account_history,
+                                                 last_bought=self.last_bought,
+                                                 last_sold=self.last_sold,
+                                                 current_price=self._current_price())
 
         return reward if np.isfinite(reward) else 0
 
diff --git a/lib/env/__init__.py b/lib/env/__init__.py
index 01c5c0f..55ba2d1 100644
--- a/lib/env/__init__.py
+++ b/lib/env/__init__.py
@@ -1 +1,4 @@
 from lib.env.TradingEnv import TradingEnv
+from lib.env.render.TradingChart import TradingChart
+
+import lib.env.reward as reward
diff --git a/lib/env/reward/BaseRewardStrategy.py b/lib/env/reward/BaseRewardStrategy.py
new file mode 100644
index 0000000..741d416
--- /dev/null
+++ b/lib/env/reward/BaseRewardStrategy.py
@@ -0,0 +1,16 @@
+import pandas as pd
+
+from abc import ABCMeta, abstractmethod
+from typing import List
+
+
+class BaseRewardStrategy(object, metaclass=ABCMeta):
+    @staticmethod
+    @abstractmethod
+    def get_reward(observations: pd.DataFrame,
+                   account_history: pd.DataFrame,
+                   net_worths: List[float],
+                   last_bought: int,
+                   last_sold: int,
+                   current_price: float) -> float:
+        raise NotImplementedError()
diff --git a/lib/env/reward/IncrementalProfit.py b/lib/env/reward/IncrementalProfit.py
new file mode 100644
index 0000000..880e440
--- /dev/null
+++ b/lib/env/reward/IncrementalProfit.py
@@ -0,0 +1,25 @@
+import pandas as pd
+
+from typing import List
+
+from lib.env.reward.BaseRewardStrategy import BaseRewardStrategy
+
+
+class IncrementalProfit(BaseRewardStrategy):
+    @staticmethod
+    def get_reward(observations: pd.DataFrame,
+                   account_history: pd.DataFrame,
+                   net_worths: List[float],
+                   last_bought: int,
+                   last_sold: int,
+                   current_price: float):
+        prev_balance = account_history['balance'].values[-2]
+        curr_balance = account_history['balance'].values[-1]
+        reward = 0
+
+        if curr_balance > prev_balance:
+            reward = net_worths[-1] - net_worths[last_bought]
+        elif curr_balance < prev_balance:
+            reward = observations['Close'].values[last_sold] - current_price
+
+        return reward
diff --git a/lib/env/reward/__init__.py b/lib/env/reward/__init__.py
new file mode 100644
index 0000000..6a95835
--- /dev/null
+++ b/lib/env/reward/__init__.py
@@ -0,0 +1,2 @@
+from lib.env.reward.BaseRewardStrategy import BaseRewardStrategy
+from lib.env.reward.IncrementalProfit import IncrementalProfit
diff --git a/optimize.py b/optimize.py
index 6548ebf..01f744c 100644
--- a/optimize.py
+++ b/optimize.py
@@ -12,19 +12,21 @@ def optimize_code(params):
 
 
 if __name__ == '__main__':
-    n_process = multiprocessing.cpu_count()
-    params = {}
+    n_process = multiprocessing.cpu_count() - 4
+    params = {
+        'n_envs': n_process
+    }
 
-    process = []
-    for i in range(n_process):
-        process.append(multiprocessing.Process(target=optimize_code, args=(params,)))
+    # process = []
+    # for i in range(n_process):
+    #     process.append(multiprocessing.Process(target=optimize_code, args=(params,)))
 
-    for p in process:
-        p.start()
+    # for p in process:
+    #     p.start()
 
-    for p in process:
-        p.join()
+    # for p in process:
+    #     p.join()
 
     trader = RLTrader(**params)
-    trader.train(test_trained_model=True, render_trained_model=True)
+    trader.train(test_trained_model=True, render_trained_model=False)
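
Note (not part of the patch): after this change, a reward strategy is any subclass of BaseRewardStrategy that implements the static get_reward() method, and it is passed to TradingEnv through the new reward_strategy argument (IncrementalProfit is the default). A minimal sketch of a custom strategy follows, assuming the package layout introduced above; the NetWorthChange class and the data_provider variable are hypothetical illustrations, not code from this repository:

import pandas as pd

from typing import List

from lib.env import TradingEnv
from lib.env.reward import BaseRewardStrategy


class NetWorthChange(BaseRewardStrategy):
    # Hypothetical example strategy: reward the most recent change in net worth.
    @staticmethod
    def get_reward(observations: pd.DataFrame,
                   account_history: pd.DataFrame,
                   net_worths: List[float],
                   last_bought: int,
                   last_sold: int,
                   current_price: float) -> float:
        # Zero reward on the first step; afterwards, the latest step-over-step change.
        return net_worths[-1] - net_worths[-2] if len(net_worths) > 1 else 0.0


# `data_provider` stands in for any BaseDataProvider instance (hypothetical here):
# env = TradingEnv(data_provider=data_provider, reward_strategy=NetWorthChange)

Because get_reward() is a staticmethod, the strategy class itself (rather than an instance) can be passed as reward_strategy, which is how the IncrementalProfit default is used in TradingEnv.__init__ above.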