diff --git a/tasks/rl_hopper.py b/tasks/rl_hopper.py index d719c93..fbdc4c9 100644 --- a/tasks/rl_hopper.py +++ b/tasks/rl_hopper.py @@ -261,10 +261,11 @@ def prepare_data_tuples(states, actions, rewards, num_layers, skip_steps): for j in range(1000): if random.random() <= epsilon or course == 0: selected_action = env.action_space.sample() + # quantize to -1 0 1 + selected_action = np.round(selected_action) else: a = model.react(alg.State(observation.data), stable_state) - # selected_action = np.clip(np.asarray(a.data), -1, 1) - selected_action = np.where(np.asarray(a.data) > 0, 1, -1) + selected_action = np.clip(np.asarray(a.data), -1, 1) next_observation, reward, terminated, truncated, info = env.step(selected_action) # check for nan