delayed_mcesp.py
"""
MCESP for Game-Delayed Reinforcements
"""
import numpy as np
class MCESP_D:
    def __init__(self, observations, field=np.zeros((1, 1))):
        """
        Constructor for MCESP-D. Field integration is currently stubbed.

        Parameters
        ----------
        observations : int
            The allowed specificity of stress levels.
        field : array, shape (H, W)
            A NumPy grayscale image.
        """
        self.actions = 2
        self.observations = observations
        self.q_table = np.ones((self.observations, self.actions))
        self.c_table = np.zeros((self.observations, self.actions))
        self.set_prior(field)
    def set_prior(self, field):
        """
        Set the initial observation discretization to the dimensionality of observations.
        The discretization is initially uniform and its learning rate is set to 1.
        """
        self.observation_thresholds = [i / self.observations for i in range(self.observations)]
        self.observation_samples = 1
        # TODO: For use after integrating image processing with MCESP for Game-Delayed Reinforcements
        # self.norm = field.max()
    def update_reward(self, observation, action, reward):
        """
        Update the Q-table when a delayed reward is received from a subsequent layer.
        """
        # Canonical Q-update with learning rate alpha given by the count-based schedule
        alpha = self.count(observation, action)
        self.q_table[observation, action] = (1 - alpha) * self.q_table[observation, action] + alpha * reward
        self.increment_count(observation, action)
    def count(self, observation, action):
        """
        Q-learning learning-rate schedule: alpha_n = 1 / (1 + n), where n is the number
        of updates already applied to this (observation, action) pair.
        """
        return 1 / (1 + self.c_table[observation, action])
    def increment_count(self, observation, action):
        self.c_table[observation, action] += 1
    def act(self, observation):
        """
        Return the currently learned max action for this layer. Ties are broken randomly.
        """
        maximum_actions = np.argwhere(self.q_table[observation] == np.amax(self.q_table[observation])).flatten()
        return np.random.choice(maximum_actions)
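# ---------------------------------------------------------------------------
# Example usage (a minimal sketch, not part of the original module): the loop,
# observation indices, and reward values below are hypothetical and only
# illustrate how a caller might drive the act / update_reward cycle when
# rewards arrive one step late from a subsequent layer.
if __name__ == "__main__":
    agent = MCESP_D(observations=4)  # 4 discretized stress levels (assumed)
    pending = []                     # (observation, action) pairs awaiting their delayed reward
    rng = np.random.default_rng(0)
    for step in range(10):
        obs = rng.integers(agent.observations)  # hypothetical observation index
        action = agent.act(obs)
        pending.append((obs, action))
        if step > 0:
            # Reward for the previous step arrives now (placeholder value).
            delayed_obs, delayed_action = pending.pop(0)
            reward = rng.random()
            agent.update_reward(delayed_obs, delayed_action, reward)
    print(agent.q_table)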