Commit
Refactoring: Move common methods and classes to separate package
adik993 committed Dec 14, 2017
1 parent 3fea3c1 commit c6a93a4
Showing 15 changed files with 78 additions and 92 deletions.
3 changes: 1 addition & 2 deletions cliffwalking.py
@@ -1,9 +1,8 @@
from gym import Env

from envs.CliffWalkingEnv import CliffWalking
-from windy_gridworld import Sarsa, generate_episode
import numpy as np
from log import make_logger
+from windy_gridworld import Sarsa, generate_episode

log = make_logger(__name__)

28 changes: 1 addition & 27 deletions double_q_learning.py
@@ -6,33 +6,7 @@
import plotly.offline as py
import plotly.graph_objs as go


-def randomargmax(d, key=None):
-    k_max = max(d, key=key)
-    return np.random.choice([k for k, v in d.items() if d[k_max] == v])
-
-
-def epsilon_prob(greedy, action, n_actions, epsilon):
-    if greedy == action:
-        return epsilon_greedy_prob(n_actions, epsilon)
-    else:
-        return epsilon_explore_prob(n_actions, epsilon)
-
-
-def epsilon_greedy_prob(n_actions, epsilon):
-    return 1 - epsilon + epsilon / n_actions
-
-
-def epsilon_explore_prob(n_actions, epsilon):
-    return epsilon / n_actions
-
-
-class Algorithm:
-    def action(self, state):
-        raise NotImplementedError()
-
-    def on_new_state(self, state, action, reward, next_state, done):
-        raise NotImplementedError()
+from utils import epsilon_prob, randomargmax, Algorithm


class QLearning(Algorithm):
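
Side note (not part of the diff): epsilon_prob returns the probability an ε-greedy policy assigns to an action. With epsilon = 0.1 and n_actions = 4, the greedy action gets 1 - 0.1 + 0.1/4 = 0.925 and every other action gets 0.1/4 = 0.025, so the four probabilities sum to 1. A minimal check against the helper that now lives in utils (assumes the repository root is on the import path):

from utils import epsilon_prob

n_actions, epsilon = 4, 0.1
probs = [epsilon_prob(greedy=0, action=a, n_actions=n_actions, epsilon=epsilon)
         for a in range(n_actions)]
# greedy action: ~0.925, each of the other three: ~0.025, total: ~1.0
print(probs, sum(probs))
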
9 changes: 3 additions & 6 deletions dyna_q.py
@@ -1,15 +1,12 @@
import random

-from double_q_learning import epsilon_prob
-from envs.MazeEnv import BasicMaze, Maze, MazeShortLong, MazeLongShort
-from n_step_sarsa import Algorithm
import numpy as np
import plotly.graph_objs as go
import plotly.offline as py


-def randomargmax(a: np.ndarray):
-    return np.random.choice(np.flatnonzero(a == a.max()))
+from double_q_learning import epsilon_prob
+from envs.MazeEnv import Maze, MazeLongShort
+from utils import randomargmax, Algorithm


class DynaQ(Algorithm):
8 changes: 2 additions & 6 deletions envs/CliffWalkingEnv.py
@@ -3,10 +3,6 @@
import numpy as np


-def minmax(value, low, high):
-    return max(min(value, high), low)


class CliffWalking(Env):
    metadata = {'render.modes': ['human']}
    ACTION_UP = 0
@@ -55,8 +51,8 @@ def _reward(self, felt):
        return -100 if felt else -1

    def _move(self, by):
-        axis0 = minmax(self.position[0] + by[0], 0, self.world.shape[0] - 1)
-        axis1 = minmax(self.position[1] + by[1], 0, self.world.shape[1] - 1)
+        axis0 = np.clip(self.position[0] + by[0], 0, self.world.shape[0] - 1)
+        axis1 = np.clip(self.position[1] + by[1], 0, self.world.shape[1] - 1)
        felt = False
        if self.world[axis0, axis1] == CliffWalking.CLIFF:
            felt = True
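
Aside (not part of the commit): the removed minmax helper and np.clip do the same scalar clamping, so this substitution is behavior-preserving apart from np.clip returning a NumPy scalar. A quick sanity check, reproducing the deleted helper for comparison:

import numpy as np

def minmax(value, low, high):
    # the helper this commit deletes from the env modules
    return max(min(value, high), low)

for value in (-3, 0, 7, 42):
    # np.clip returns a NumPy scalar, but it compares equal to the plain int result
    assert minmax(value, 0, 10) == np.clip(value, 0, 10)
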
6 changes: 2 additions & 4 deletions envs/GridWorldEnv.py
@@ -2,8 +2,6 @@
import numpy as np
from gym.spaces import Tuple, Discrete

-from envs.WindyGridWorldEnv import minmax


class GridWorld(Env):
    metadata = {'render.modes': ['human']}
@@ -48,8 +46,8 @@ def _step(self, action):
        return self._obs(), -1, done, self.world

    def _move(self, move):
-        axis0 = minmax(self.position[0] + move[0], 0, self.world.shape[0] - 1)
-        axis1 = minmax(self.position[1] + move[1], 0, self.world.shape[1] - 1)
+        axis0 = np.clip(self.position[0] + move[0], 0, self.world.shape[0] - 1)
+        axis1 = np.clip(self.position[1] + move[1], 0, self.world.shape[1] - 1)
        self.position = (axis0, axis1)

    def _reset(self):
7 changes: 3 additions & 4 deletions envs/MazeEnv.py
@@ -1,7 +1,6 @@
from gym import Env
from gym.spaces import Tuple, Discrete

-from envs.CliffWalkingEnv import minmax
+import numpy as np
from envs.GridWorldEnv import GridWorld


@@ -18,8 +17,8 @@ def is_wall(self, position):
        return self.world[position] == Maze.WALL

    def _move(self, move):
-        axis0 = minmax(self.position[0] + move[0], 0, self.world.shape[0] - 1)
-        axis1 = minmax(self.position[1] + move[1], 0, self.world.shape[1] - 1)
+        axis0 = np.clip(self.position[0] + move[0], 0, self.world.shape[0] - 1)
+        axis1 = np.clip(self.position[1] + move[1], 0, self.world.shape[1] - 1)
        if not self.is_wall((axis0, axis1)):
            self.position = (axis0, axis1)

4 changes: 1 addition & 3 deletions envs/RandomWalkEnv.py
@@ -2,8 +2,6 @@
from gym.spaces import Discrete
import numpy as np

-from envs.WindyGridWorldEnv import minmax


class RandomWalk(Env):
    metadata = {'render.modes': ['human']}
@@ -38,7 +36,7 @@ def _step(self, action):
            self.position -= step
        else:
            self.position += step
-        self.position = minmax(self.position, 0, len(self.states) - 1)
+        self.position = np.clip(self.position, 0, len(self.states) - 1)

        done = self.position == 0 or self.position == len(self.states) - 1
        reward = self.states[self.position]
8 changes: 2 additions & 6 deletions envs/WindyGridWorldEnv.py
@@ -7,10 +7,6 @@ def inc(tuple, val):
    return tuple[0] + val, tuple[1] + val


-def minmax(value, low, high):
-    return max(min(value, high), low)


class WindyGridWorld(Env):
    metadata = {'render.modes': ['human']}
    ACTION_UP = 0
@@ -71,9 +67,9 @@ def _step(self, action):

    def _move(self, by):
        wind = self._get_wind(self.position[1])
-        axis1 = minmax(self.position[1] + by[1], 0, self.size[1] - 1)
+        axis1 = np.clip(self.position[1] + by[1], 0, self.size[1] - 1)
        axis0 = self.position[0] + by[0] - wind
-        axis0 = minmax(axis0, 0, self.size[0] - 1)
+        axis0 = np.clip(axis0, 0, self.size[0] - 1)
        self.position = axis0, axis1

    def _get_wind(self, axis1):
14 changes: 3 additions & 11 deletions gradient_methods_random_walk.py
@@ -1,12 +1,12 @@
from gym import Env

-from double_q_learning import Algorithm
-from envs.CliffWalkingEnv import minmax
from envs.RandomWalkEnv import RandomWalk
import numpy as np
import plotly.offline as py
import plotly.graph_objs as go

+from utils import EpisodeAlgorithm, Algorithm

N_AGGREGATE = 100
N_STATES = 1000
MAX_STEP = 100
@@ -21,7 +21,7 @@ def find_true_values():
for action in [-1, 1]:
for step in range(1, MAX_STEP + 1):
step *= action
-next_state = minmax(state + step, 0, N_STATES + 1)
+next_state = np.clip(state + step, 0, N_STATES + 1)
prob = 1 / (MAX_STEP * 2)
new[state] += prob * (0 + new[next_state])
error = np.abs(np.sum(old - new))
@@ -40,14 +40,6 @@ def __init__(self, state, reward):
        self.reward = reward


-class EpisodeAlgorithm:
-    def action(self, state):
-        raise NotImplementedError()
-
-    def on_new_episode(self, history):
-        raise NotImplementedError()
-
-
class ValueFunction:
    def __init__(self, shape, aggregation=N_AGGREGATE):
        self.value = np.zeros([s // aggregation for s in shape])
8 changes: 1 addition & 7 deletions n_step_sarsa.py
@@ -6,13 +6,7 @@
import plotly.offline as py
import plotly.graph_objs as go


-class Algorithm:
-    def action(self, state):
-        raise NotImplementedError()
-
-    def on_new_state(self, state, action, reward, next_state, done):
-        raise NotImplementedError()
+from utils import Algorithm


class NStepSarsa(Algorithm):
15 changes: 4 additions & 11 deletions n_step_td_random_walk.py
@@ -1,22 +1,15 @@
import sys
from collections import deque

-from envs.RandomWalkEnv import RandomWalk, Env
-from randomwalk import rmse
import numpy as np
-import plotly.offline as py
import plotly.graph_objs as go
+import plotly.offline as py

+from envs.RandomWalkEnv import RandomWalk
+from randomwalk import rmse
+from utils import Algorithm

TRUE_VALUES = np.arange(-20, 22, 2) / 20.0

-class Algorithm:
-    def action(self, state):
-        raise NotImplementedError()
-
-    def on_new_state(self, state, action, reward, next_state, done):
-        raise NotImplementedError()


class NStepTD(Algorithm):
def __init__(self, env: RandomWalk, n, alpha=0.1, gamma=1):
3 changes: 2 additions & 1 deletion n_step_tree_backup.py
@@ -7,7 +7,8 @@

from double_q_learning import epsilon_prob
from envs.GridWorldEnv import GridWorld
-from n_step_sarsa import Algorithm, perform_algo_eval, NStepSarsa
+from n_step_sarsa import perform_algo_eval
+from utils import Algorithm


class Entry:
8 changes: 5 additions & 3 deletions random_walk_td_lambda.py
@@ -7,6 +7,8 @@
import plotly.offline as py
import plotly.graph_objs as go

+from utils import Algorithm


class RandomPolicy:
    def __init__(self, env: Env):
@@ -26,7 +28,7 @@ def __getitem__(self, item):
        return self.actions[self.index]


-class TD:
+class TD(Algorithm):
    def __init__(self, env: Env, policy, alpha=0.1, gamma=1, lam=0.9):
        self.alpha = alpha
        self.gamma = gamma
@@ -41,7 +43,7 @@ def trace(self, state):
    def action(self, state):
        return self.policy[state]

-    def on_new_state(self, state, reward, next_state, done):
+    def on_new_state(self, state, action, reward, next_state, done):
        v = self.values[state]
        v_next = self.values[next_state]
        delta = reward + self.gamma * v_next - v
@@ -61,7 +63,7 @@ def generate_episode(env: Env, algorithm: TD):
        prev_obs = obs
        action = algorithm.action(prev_obs)
        obs, reward, done, aux = env.step(action)
-        algorithm.on_new_state(prev_obs, reward, obs, done)
+        algorithm.on_new_state(prev_obs, action, reward, obs, done)


def perform_lam_test(env, lams, alphas, n_avg=1, n=10):
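
Aside (not part of the diff): passing the action through on_new_state above brings TD in line with the shared utils.Algorithm interface, so one episode loop can drive any algorithm. Roughly, as an illustrative sketch only (run_episode is a hypothetical name, not the repository's function):

def run_episode(env, algorithm):
    # generic loop: every utils.Algorithm sees (state, action, reward, next_state, done)
    obs = env.reset()
    done = False
    while not done:
        action = algorithm.action(obs)
        next_obs, reward, done, _ = env.step(action)
        algorithm.on_new_state(obs, action, reward, next_obs, done)
        obs = next_obs
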
45 changes: 45 additions & 0 deletions utils/__init__.py
@@ -0,0 +1,45 @@
import numpy as np
from math import ceil

def randomargmax(d, key=None):
    k_max = max(d, key=key)
    return np.random.choice([k for k, v in d.items() if d[k_max] == v])


def randargmax(b, **kw):
    """ a random tie-breaking argmax"""
    return np.argmax(np.random.random(b.shape) * (b == b.max()), **kw)


def epsilon_prob(greedy, action, n_actions, epsilon):
    if greedy == action:
        return epsilon_greedy_prob(n_actions, epsilon)
    else:
        return epsilon_explore_prob(n_actions, epsilon)


def epsilon_greedy_prob(n_actions, epsilon):
    return 1 - epsilon + epsilon / n_actions


def epsilon_explore_prob(n_actions, epsilon):
    return epsilon / n_actions


def calc_batch_size(size, n_batches, batch_idx):
    return max(0, min(size - batch_idx * ceil(size / n_batches), ceil(size / n_batches)))

class Algorithm:
    def action(self, state):
        raise NotImplementedError()

    def on_new_state(self, state, action, reward, next_state, done):
        raise NotImplementedError()


class EpisodeAlgorithm:
    def action(self, state):
        raise NotImplementedError()

    def on_new_episode(self, history):
        raise NotImplementedError()
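
Usage sketch (illustration only, not part of the diff), assuming the repository root is on the import path:

import numpy as np
from utils import Algorithm, randargmax, randomargmax

# randomargmax breaks ties between dictionary keys at random
q = {'left': 1.0, 'right': 1.0, 'down': 0.2}
print(randomargmax(q, key=q.get))             # 'left' or 'right', picked at random

# randargmax does the same for arrays
print(randargmax(np.array([0.5, 0.9, 0.9])))  # index 1 or 2

# concrete algorithms implement the Algorithm interface used by the training scripts
class GreedyPolicy(Algorithm):
    def __init__(self, q_table):
        self.q_table = q_table  # e.g. a dict mapping state -> np.ndarray of action values

    def action(self, state):
        return randargmax(self.q_table[state])

    def on_new_state(self, state, action, reward, next_state, done):
        pass  # this toy policy does not learn
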
4 changes: 3 additions & 1 deletion windy_gridworld.py
@@ -4,10 +4,12 @@
from log import make_logger
import numpy as np

+from utils import Algorithm

log = make_logger(__name__)


-class Sarsa:
+class Sarsa(Algorithm):
    def __init__(self, env: Env, alpha=0.5, gamma=1, epsilon=0.1):
        self.alpha = alpha
        self.gamma = gamma