Implement BaseRewardStrategy interface more clearly.
notadamking committed Jul 6, 2019
1 parent 36d3717 commit 80a55cf
Showing 6 changed files with 67 additions and 30 deletions.
29 changes: 9 additions & 20 deletions lib/env/TradingEnv.py
@@ -5,6 +5,7 @@
from gym import spaces

from lib.env.render import TradingChart
from lib.env.reward import BaseRewardStrategy, IncrementalProfit
from lib.data.providers import BaseDataProvider
from lib.data.features.transform import max_min_normalize, log_and_difference
from lib.util.logger import init_logger
@@ -15,16 +16,16 @@ class TradingEnv(gym.Env):
metadata = {'render.modes': ['human', 'system', 'none']}
viewer = None

def __init__(self, data_provider: BaseDataProvider, initial_balance=10000, commission=0.0025, **kwargs):
def __init__(self, data_provider: BaseDataProvider, reward_strategy: BaseRewardStrategy = IncrementalProfit, initial_balance=10000, commission=0.0025, **kwargs):
super(TradingEnv, self).__init__()

self.logger = kwargs.get('logger', init_logger(__name__, show_debug=kwargs.get('show_debug', True)))

self.data_provider = data_provider
self.reward_strategy = reward_strategy
self.initial_balance = initial_balance
self.commission = commission

self.reward_fn = kwargs.get('reward_fn', self._reward_incremental_profit)
self.benchmarks = kwargs.get('benchmarks', [])
self.enable_stationarization = kwargs.get('enable_stationarization', True)

@@ -101,25 +102,13 @@ def _take_action(self, action):
'revenue_from_sold': revenue_from_sold,
}, ignore_index=True)

def _reward_incremental_profit(self, observations, net_worths, account_history, last_bought, last_sold, current_price):
prev_balance = account_history['balance'].values[-2]
curr_balance = account_history['balance'].values[-1]
reward = 0

if curr_balance > prev_balance:
reward = net_worths[-1] - net_worths[last_bought]
elif curr_balance < prev_balance:
reward = observations['Close'].values[last_sold] - current_price

return reward

def _reward(self):
reward = self.reward_fn(observations=self.observations,
net_worths=self.net_worths,
account_history=self.account_history,
last_bought=self.last_bought,
last_sold=self.last_sold,
current_price=self._current_price())
reward = self.reward_strategy.get_reward(observations=self.observations,
net_worths=self.net_worths,
account_history=self.account_history,
last_bought=self.last_bought,
last_sold=self.last_sold,
current_price=self._current_price())

return reward if np.isfinite(reward) else 0

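For context, here is a minimal usage sketch of the constructor signature introduced above. It is not part of the commit: make_env and its provider argument are placeholders for any concrete BaseDataProvider, and only the TradingEnv arguments shown come from this change.

from lib.env import TradingEnv
from lib.env.reward import IncrementalProfit


def make_env(provider):
    # `provider` must be a concrete BaseDataProvider instance; its construction
    # is outside the scope of this commit and is not shown here.
    return TradingEnv(data_provider=provider,
                      reward_strategy=IncrementalProfit,  # the default, passed explicitly for clarity
                      initial_balance=10000,
                      commission=0.0025)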
3 changes: 3 additions & 0 deletions lib/env/__init__.py
@@ -1 +1,4 @@
from lib.env.TradingEnv import TradingEnv
from lib.env.render.TradingChart import TradingChart

import lib.env.reward as reward
16 changes: 16 additions & 0 deletions lib/env/reward/BaseRewardStrategy.py
@@ -0,0 +1,16 @@
import pandas as pd

from abc import ABCMeta, abstractmethod
from typing import List


class BaseRewardStrategy(object, metaclass=ABCMeta):
@staticmethod
@abstractmethod
def get_reward(observations: pd.DataFrame,
account_history: pd.DataFrame,
net_worths: List[float],
last_bought: int,
last_sold: int,
current_price: float) -> float:
raise NotImplementedError()
25 changes: 25 additions & 0 deletions lib/env/reward/IncrementalProfit.py
@@ -0,0 +1,25 @@
import pandas as pd

from typing import List

from lib.env.reward.BaseRewardStrategy import BaseRewardStrategy


class IncrementalProfit(BaseRewardStrategy):
@staticmethod
def get_reward(observations: pd.DataFrame,
account_history: pd.DataFrame,
net_worths: List[float],
last_bought: int,
last_sold: int,
current_price: float):
prev_balance = account_history['balance'].values[-2]
curr_balance = account_history['balance'].values[-1]
reward = 0

if curr_balance > prev_balance:
reward = net_worths[-1] - net_worths[last_bought]
elif curr_balance < prev_balance:
reward = observations['Close'].values[last_sold] - current_price

return reward
2 changes: 2 additions & 0 deletions lib/env/reward/__init__.py
@@ -0,0 +1,2 @@
from lib.env.reward.IncrementalProfit import IncrementalProfit
from lib.env.reward.BaseRewardStrategy import BaseRewardStrategy
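The new lib/env/reward package above defines the extension point this commit introduces: a reward strategy is any subclass of BaseRewardStrategy, selected via the reward_strategy constructor argument of TradingEnv. As an illustration only, a custom strategy might look like the sketch below; the NetWorthChange class and its reward rule are invented for this example, while the get_reward signature mirrors the abstract method above.

import pandas as pd

from typing import List

from lib.env.reward import BaseRewardStrategy


class NetWorthChange(BaseRewardStrategy):
    @staticmethod
    def get_reward(observations: pd.DataFrame,
                   account_history: pd.DataFrame,
                   net_worths: List[float],
                   last_bought: int,
                   last_sold: int,
                   current_price: float) -> float:
        # Hypothetical rule: reward the most recent change in net worth.
        if len(net_worths) < 2:
            return 0
        return net_worths[-1] - net_worths[-2]

It would then be plugged in as TradingEnv(data_provider=provider, reward_strategy=NetWorthChange), where provider is again any concrete BaseDataProvider.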
22 changes: 12 additions & 10 deletions optimize.py
@@ -12,19 +12,21 @@ def optimize_code(params):


if __name__ == '__main__':
n_process = multiprocessing.cpu_count()
params = {}
n_process = multiprocessing.cpu_count() - 4
params = {
'n_envs': n_process
}

process = []
for i in range(n_process):
process.append(multiprocessing.Process(target=optimize_code, args=(params,)))
# process = []
# for i in range(n_process):
# process.append(multiprocessing.Process(target=optimize_code, args=(params,)))

for p in process:
p.start()
# for p in process:
# p.start()

for p in process:
p.join()
# for p in process:
# p.join()

trader = RLTrader(**params)

trader.train(test_trained_model=True, render_trained_model=True)
trader.train(test_trained_model=True, render_trained_model=False)
