Commit

Merge pull request octo-models#142 from rail-berkeley/add_gym_wrapper
Adding Global Gym Wrappers
dibyaghosh authored Dec 8, 2023
2 parents 86e6715 + 71816b2 commit 5965a96
Showing 1 changed file with 98 additions and 1 deletion.
99 changes: 98 additions & 1 deletion orca/utils/gym_wrappers.py
@@ -1,9 +1,12 @@
from collections import deque
import logging
from typing import Optional, Sequence, Tuple, Union

import gym
import gym.spaces
import jax
import numpy as np
import tensorflow as tf


def stack_and_pad(history: list, num_obs: int):
@@ -47,6 +50,57 @@ def listdict2dictlist(LD):
    return {k: [dic[k] for dic in LD] for k in LD[0]}


def add_orca_env_wrappers(
    env: gym.Env, config: dict, dataset_statistics: dict, **kwargs
):
    """Adds env wrappers for action normalization, multi-action
    future prediction, image resizing, and history stacking.
    Uses defaults from the model config, but all can be overridden through kwargs.
    Arguments:
        env: gym Env
        config: PretrainedModel.config
        dataset_statistics: from PretrainedModel.load_dataset_statistics
        # Additional (optional) kwargs
        normalization_type: str for UnnormalizeActionProprio
        exec_horizon: int for RHCWrapper
        resize_size: None or tuple or list of tuples for ResizeImageWrapper
        horizon: int for HistoryWrapper
    """
    normalization_type = kwargs.get(
        "normalization_type",
        config["dataset_kwargs"]["common_dataset_kwargs"][
            "action_proprio_normalization_type"
        ],
    )

    logging.info(
        "Unnormalizing proprio and actions w/ statistics: %s", dataset_statistics
    )
    env = UnnormalizeActionProprio(env, dataset_statistics, normalization_type)
    exec_horizon = kwargs.get(
        "exec_horizon", config["model"]["heads"]["action"]["kwargs"]["pred_horizon"]
    )

    logging.info("Running receding horizon control with exec_horizon: %s", exec_horizon)
    env = RHCWrapper(env, exec_horizon)
    resize_size = kwargs.get(
        "resize_size",
        config["dataset_kwargs"]["frame_transform_kwargs"]["resize_size"],
    )

    logging.info("Resizing images w/ parameters: %s", resize_size)
    env = ResizeImageWrapper(env, resize_size)

    horizon = kwargs.get("horizon", config["window_size"])
    logging.info("Adding history of size: %s", horizon)
    env = HistoryWrapper(env, horizon)

    logging.info("New observation space: %s", env.observation_space)
    return env
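
# Illustrative usage (a sketch, not part of this commit): `model` stands for a
# hypothetical loaded pretrained model exposing `.config`, `dataset_statistics`
# for statistics loaded via PretrainedModel.load_dataset_statistics, and
# `base_env` for any gym.Env with a Dict observation space containing
# "image_*" keys.
#
#   env = add_orca_env_wrappers(
#       base_env,
#       config=model.config,
#       dataset_statistics=dataset_statistics,
#       exec_horizon=1,  # optional override for RHCWrapper
#   )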


class HistoryWrapper(gym.Wrapper):
"""
Accumulates the observation history into `horizon` size chunks. If the length of the history
Expand Down Expand Up @@ -157,13 +211,56 @@ def step(self, actions):
        return self.env.step(action)


class ResizeImageWrapper(gym.ObservationWrapper):
    def __init__(
        self,
        env: gym.Env,
        resize_size: Optional[Union[Tuple, Sequence[Tuple]]],
    ):
        super().__init__(env)
        assert isinstance(
            self.observation_space, gym.spaces.Dict
        ), "Only Dict observation spaces are supported."
        spaces = self.observation_space.spaces

        if resize_size is None:
            self.keys_to_resize = {}
        elif isinstance(resize_size, tuple):
            self.keys_to_resize = {k: resize_size for k in spaces if "image_" in k}
        else:
            self.keys_to_resize = {
                f"image_{i}": resize_size[i] for i in range(len(resize_size))
            }
        logging.info(f"Resizing images: {self.keys_to_resize}")
        for k, size in self.keys_to_resize.items():
            spaces[k] = gym.spaces.Box(
                low=0,
                high=255,
                shape=size + (3,),
                dtype=np.uint8,
            )
        self.observation_space = gym.spaces.Dict(spaces)

    def observation(self, observation):
        for k, size in self.keys_to_resize.items():
            image = tf.image.resize(
                observation[k], size=size, method="lanczos3", antialias=True
            )
            image = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8).numpy()
            observation[k] = image
        return observation
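
# Note on `resize_size` (illustrative, not from the diff): a single (height, width)
# tuple is applied to every "image_*" key in the observation space, while a list
# of tuples is applied positionally to "image_0", "image_1", ...
#
#   env = ResizeImageWrapper(env, (256, 256))                  # all image_* keys
#   env = ResizeImageWrapper(env, [(256, 256), (128, 128)])    # image_0, image_1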


class UnnormalizeActionProprio(gym.ActionWrapper, gym.ObservationWrapper):
    """
    Un-normalizes the action and proprio.
    """

    def __init__(
        self,
        env: gym.Env,
        action_proprio_metadata: dict,
        normalization_type: str,
    ):
        self.action_proprio_metadata = jax.tree_map(
            lambda x: np.array(x),
