Skip to content

Commit

Permalink
potential fix for GravityChangeWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
becktepe committed Apr 9, 2024
1 parent 398a0da commit 4aa5bf1
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions adrl/continual_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class GravityChangeWrapper(Wrapper):
def __init__(self, env):
super().__init__(env)
self.n_steps = 0
self.n_total_steps = 0
self.n_switches = 0

def step(self, action):
Expand All @@ -16,10 +17,15 @@ def step(self, action):
truncated = True
return state, reward, terminated, truncated, info

def reset(self):
self.env.reset()
if self.n_steps // 10000 <= self.n_switches:
change_kind = np.random.choice(["flip", "random"])
def reset(self):
self.n_total_steps += self.n_steps
self.n_steps = 0

if self.n_total_steps // 10000 > self.n_switches:
# as gravity has to be in (-20. 0.01) flipping does not make sense
# change_kind = np.random.choice(["flip", "random"])
change_kind = "random"

if change_kind == "flip":
gravity = -self.env.context["GRAVITY_Y"]
else:
Expand Down

0 comments on commit 4aa5bf1

Please sign in to comment.