-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Differential semi-gradient Sarsa: Access control task done
Note that the average reward converges to about 2.6, not the 2.31 stated in the book, and the algorithm fails when run for 2 million iterations, producing a different policy. I am not sure what the problem is; perhaps the parameters should be different.
- Loading branch information
Showing
3 changed files
with
240 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import numpy as np | ||
from gym import Env | ||
import plotly.offline as py | ||
import plotly.graph_objs as go | ||
from plotly import tools | ||
|
||
from envs.AcessControlQueueEnv import AccessControlQueueTimeLimit, AccessControlQueue | ||
from utils import Algorithm, randargmax, generate_episode, epsilon_probs, TilingValueFunction | ||
|
||
np.random.seed(7) | ||
|
||
|
||
class ValueFunction(TilingValueFunction):
    """Tile-coding action-value function over (priority, free-servers) states."""

    def __init__(self, n_tilings, max_size, n_priorities, n_servers):
        """
        :param n_tilings: number of overlapping tilings
        :param max_size: size of the underlying weight/hash table
        :param n_priorities: number of distinct customer priorities
        :param n_servers: total number of servers in the environment
        """
        super().__init__(n_tilings, max_size)
        # Keep the largest *index*, so the scaling below maps the priority
        # range exactly onto [0, n_tilings].
        self.n_priorities = n_priorities - 1
        self.n_servers = n_servers

    def scaled_values(self, state):
        """Scale each raw state dimension to the tiling resolution."""
        priority, free_servers = state
        return [
            priority * (self.n_tilings / self.n_priorities),
            free_servers * (self.n_tilings / self.n_servers),
        ]
|
||
|
||
class DifferentialSemiGradientSarsa(Algorithm):
    """Differential semi-gradient Sarsa for the average-reward (continuing)
    setting — Sutton & Barto, Section 10.3."""

    def __init__(self, env: Env, value_function, alpha=0.01, beta=0.01, epsilon=0.1):
        """
        :param env: gym environment providing the discrete action space
        :param value_function: approximate q(s, a) with `estimated` / item update
        :param alpha: step size for the value-function update
        :param beta: step size for the average-reward estimate
        :param epsilon: exploration probability for epsilon-greedy actions
        """
        self.alpha = alpha
        self.beta = beta
        self.epsilon = epsilon
        self.actions = np.arange(env.action_space.n)
        self.value_function = value_function
        self._reset()

    def action(self, state):
        """Return the action already chosen by the last update, if any;
        otherwise select one for the given state."""
        if self.next_action is None:
            return self._action(state)
        return self.next_action

    def _action(self, state):
        """Epsilon-greedy action; rejection is forced when no server is free."""
        _, free_servers = state
        if free_servers == 0:
            return AccessControlQueue.ACTION_REJECT
        probs = epsilon_probs(self._greedy_action(state), self.actions, self.epsilon)
        return np.random.choice(self.actions, p=probs)

    def _greedy_action(self, state):
        """Action with the highest estimated value (ties broken at random)."""
        estimates = [self.value_function.estimated(state, a) for a in self.actions]
        return randargmax(np.array(estimates))

    def _reset(self):
        # average_reward is the running estimate of the reward rate R-bar.
        self.average_reward = 0
        self.next_action = None

    def on_new_state(self, state, action, reward, next_state, done):
        """Perform one differential Sarsa update for the observed transition."""
        self.next_action = self._action(next_state)
        q_next = self.value_function.estimated(next_state, self.next_action)
        q = self.value_function.estimated(state, action)
        # TD error of the average-reward formulation: no discounting,
        # rewards measured relative to the running average.
        delta = reward - self.average_reward + q_next - q
        self.average_reward += self.beta * delta
        print('Average reward:', self.average_reward)
        self.value_function[state, action] += self.alpha * delta
        if done:
            self._reset()
|
||
|
||
if __name__ == '__main__':
    n_servers = 10

    # Train differential semi-gradient Sarsa for one (time-limited) episode.
    env = AccessControlQueueTimeLimit(max_episode_steps=int(1e6),
                                      free_prob=0.06, n_servers=n_servers)
    value_function = ValueFunction(8, 2048, len(AccessControlQueue.PRIORITIES), n_servers)
    algorithm = DifferentialSemiGradientSarsa(
        env, value_function, alpha=0.01 / value_function.n_tilings)
    generate_episode(env, algorithm, print_step=True)

    # Extract the learned greedy policy and state values for plotting.
    policy = value_function.to_policy(algorithm.actions, AccessControlQueue.PRIORITIES, np.arange(n_servers + 1))
    values = value_function.to_value(algorithm.actions, AccessControlQueue.PRIORITIES, np.arange(n_servers + 1))

    fig = tools.make_subplots(rows=1, cols=2)
    fig.append_trace(go.Heatmap(z=policy,
                                x=np.arange(n_servers + 1),
                                y=AccessControlQueue.REWARDS,
                                name='Policy'), 1, 1)
    for priority, row in enumerate(values):
        # With zero free servers the only action is REJECT, so plot its value.
        row[0] = value_function.estimated((priority, 0), AccessControlQueue.ACTION_REJECT)
        fig.append_trace(go.Scatter(y=row, name='n={}'.format(AccessControlQueue.REWARDS[priority])), 1, 2)
    fig.layout.yaxis1.autorange = 'reversed'
    py.plot(fig)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from gym import Env | ||
from gym.spaces import Discrete, Tuple | ||
import numpy as np | ||
from gym.wrappers import TimeLimit | ||
|
||
|
||
class AccessControlActionSpace(Discrete):
    """Discrete(2) action space that forbids ACCEPT while every server is busy."""

    def __init__(self, env, n):
        super().__init__(n)
        # Back-reference to the owning environment to read its server state.
        self.env = env

    def sample(self):
        # With no free servers the only valid action is to reject.
        if self.env.free_servers > 0:
            return super().sample()
        return AccessControlQueue.ACTION_REJECT
|
||
|
||
class AccessControlQueue(Env):
    """Access-control queuing task (Sutton & Barto, Example 10.2).

    Customers of four priorities arrive at the head of a queue. At each step
    the agent either accepts the current customer — occupying one server and
    earning a priority-dependent reward — or rejects it (reward 0). Each busy
    server becomes free again with probability ``free_prob`` per step.
    Observations are ``(current_priority, free_servers)`` tuples. The task is
    continuing, so ``done`` is always False.
    """
    metadata = {'render.modes': ['human']}
    PRIORITIES = np.arange(4)  # indices of customer priorities
    REWARDS = [1, 2, 4, 8]     # reward for accepting each priority
    ACTION_REJECT = 0
    ACTION_ACCEPT = 1

    def __init__(self, n_servers=10, free_prob=0.06):
        """
        :param n_servers: total number of servers
        :param free_prob: per-step probability that a busy server frees up
        """
        self.n_servers = n_servers
        self.free_prob = free_prob
        self.action_space = AccessControlActionSpace(self, 2)
        self.observation_space = Tuple((
            Discrete(len(AccessControlQueue.PRIORITIES)),
            Discrete(self.n_servers + 1)
        ))
        self._reset()

    def _step(self, action):
        """Apply `action` to the current customer and advance the queue.

        Returns the standard gym 4-tuple ``(obs, reward, done, info)``.
        """
        reward = 0
        if action == AccessControlQueue.ACTION_ACCEPT:
            self._try_use_server()
            reward = AccessControlQueue.REWARDS[self.current_priority]
        # Next customer
        self.current_priority = self._pop_customer()
        self._try_free_servers()
        # info must be a dict per the gym Env.step contract (wrappers may
        # read or extend it); returning None breaks them.
        return self._obs(), reward, False, {}

    def _reset(self):
        """Free all servers and draw the first customer."""
        self.free_servers = self.n_servers
        self.current_priority = self._pop_customer()
        return self._obs()

    def _render(self, mode='human', close=False):
        print('*************************')
        print('Current ({}): {}'.format(self.current_priority, AccessControlQueue.REWARDS[self.current_priority]))
        print('Free servers:', self.free_servers)

    def _pop_customer(self):
        """Draw the next customer's priority uniformly at random."""
        return np.random.choice(AccessControlQueue.PRIORITIES)

    def _try_free_servers(self):
        """Each busy server frees independently with probability free_prob."""
        busy = self.n_servers - self.free_servers
        self.free_servers += np.random.binomial(busy, self.free_prob)

    def _try_use_server(self):
        """Occupy one server; accepting with none free is a caller error."""
        if self.free_servers == 0:
            raise ValueError('Cannot accept with all servers busy')
        self.free_servers -= 1

    def _obs(self):
        return self.current_priority, self.free_servers
|
||
|
||
class AccessControlQueueTimeLimit(TimeLimit):
    # Convenience wrapper: an AccessControlQueue whose otherwise endless
    # (continuing) episode is truncated after `max_episode_steps` steps.
    def __init__(self, max_episode_steps, n_servers=10, free_prob=0.06):
        super().__init__(AccessControlQueue(n_servers=n_servers, free_prob=free_prob),
                         max_episode_steps=max_episode_steps)
|
||
|
||
if __name__ == '__main__':
    # Smoke test: drive the environment with random (valid) actions.
    env = AccessControlQueue()
    obs = env.reset()
    for _ in range(100):
        env.render()
        action = env.action_space.sample()
        obs, reward, _, _ = env.step(action)
        action_name = 'ACCEPT' if action == AccessControlQueue.ACTION_ACCEPT else 'REJECT'
        print('Action:', action_name)
        print('Reward:', reward)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters