diff --git a/config/train_dp.json b/config/train_dp.json
deleted file mode 100644
index 519e275..0000000
--- a/config/train_dp.json
+++ /dev/null
@@ -1,129 +0,0 @@
-{
-    "dataset": {
-        "repo_id": "danielsanjosepro/clean-up-table"
-    },
-    "policy": {
-        "type": "diffusion",
-        "n_obs_steps": 1,
-        "normalization_mapping": {
-            "VISUAL": "MEAN_STD",
-            "STATE": "MEAN_STD",
-            "ACTION": "MEAN_STD"
-        },
-        "input_features": {
-            "observation.images.right_wrist_camera": {
-                "type": "VISUAL",
-                "shape": [
-                    3,
-                    256,
-                    256
-                ]
-            },
-            "observation.images.right_third_person_camera": {
-                "type": "VISUAL",
-                "shape": [
-                    3,
-                    256,
-                    256
-                ]
-            },
-            "observation.state": {
-                "type": "STATE",
-                "shape": [
-                    7
-                ]
-            }
-        },
-        "output_features": {
-            "action": {
-                "type": "ACTION",
-                "shape": [
-                    7
-                ]
-            }
-        },
-        "device": "cuda",
-        "use_amp": false,
-        "push_to_hub": false,
-        "repo_id": null,
-        "private": null,
-        "tags": null,
-        "license": null,
-        "horizon": 16,
-        "n_action_steps": 8,
-        "drop_n_last_frames": 7,
-        "vision_backbone": "resnet18",
-        "crop_shape": [
-            224,
-            224
-        ],
-        "crop_is_random": true,
-        "pretrained_backbone_weights": null,
-        "use_group_norm": true,
-        "spatial_softmax_num_keypoints": 32,
-        "use_separate_rgb_encoder_per_camera": false,
-        "down_dims": [
-            512,
-            1024,
-            2048
-        ],
-        "kernel_size": 5,
-        "n_groups": 8,
-        "diffusion_step_embed_dim": 128,
-        "use_film_scale_modulation": true,
-        "noise_scheduler_type": "DDPM",
-        "num_train_timesteps": 100,
-        "beta_schedule": "squaredcos_cap_v2",
-        "beta_start": 0.0001,
-        "beta_end": 0.02,
-        "prediction_type": "epsilon",
-        "clip_sample": true,
-        "clip_sample_range": 1.0,
-        "num_inference_steps": 10,
-        "do_mask_loss_for_padding": false,
-        "optimizer_lr": 0.0001,
-        "optimizer_betas": [
-            0.95,
-            0.999
-        ],
-        "optimizer_eps": 1e-08,
-        "optimizer_weight_decay": 1e-06,
-        "scheduler_name": "cosine",
-        "scheduler_warmup_steps": 500
-    },
-    "job_name": "diffusion",
-    "resume": false,
-    "seed": 0,
-    "steps": 20000,
-    "eval_freq": 1000,
-    "log_freq": 100,
-    "save_checkpoint": true,
-    "save_freq": 5000,
-    "num_workers": 4,
-    "batch_size": 64,
-    "use_policy_training_preset": true,
-    "optimizer": {
-        "type": "adam",
-        "lr": 0.0001,
-        "weight_decay": 1e-06,
-        "grad_clip_norm": 10.0,
-        "betas": [
-            0.95,
-            0.999
-        ],
-        "eps": 1e-08
-    },
-    "scheduler": {
-        "type": "diffuser",
-        "num_warmup_steps": 500,
-        "name": "cosine"
-    },
-    "eval": {
-        "n_episodes": 50,
-        "batch_size": 50,
-        "use_async_envs": false
-    },
-    "wandb": {
-        "enable": false
-    }
-}
diff --git a/crisp_gym/config/home.py b/crisp_gym/config/home.py
index 43d7f3a..986164f 100644
--- a/crisp_gym/config/home.py
+++ b/crisp_gym/config/home.py
@@ -1,6 +1,8 @@
 """Contains some home configurations."""
 
 # TODO: make the configs robot specific
+from enum import Enum
+
 home_close_to_table = [
     -1.73960110e-02,
     9.55319758e-02,
@@ -20,3 +22,18 @@
     1.68992915,
     0.8040582,
 ]
+
+
+class HomeConfig(Enum):
+    """Enum for different home configurations."""
+
+    CLOSE_TO_TABLE = home_close_to_table
+    FRONT_UP = home_front_up
+
+    def randomize(self, noise: float = 0.01) -> list:
+        """Randomize the home configuration."""
+        import numpy as np
+
+        return (
+            np.array(self.value) + np.random.uniform(-noise, noise, size=len(self.value))
+        ).tolist()
diff --git a/crisp_gym/manipulator_env.py b/crisp_gym/manipulator_env.py
index dde3261..23e8fa8 100644
--- a/crisp_gym/manipulator_env.py
+++ b/crisp_gym/manipulator_env.py
@@ -39,7 +39,11 @@
 from scipy.spatial.transform import Rotation
 from typing_extensions import override
 
-from crisp_gym.manipulator_env_config import ManipulatorEnvConfig, make_env_config
+from crisp_gym.manipulator_env_config import (
+    ManipulatorEnvConfig,
+    ObservationKeys,
+    make_env_config,
+)
 from crisp_gym.util.control_type import ControlType
 from crisp_gym.util.gripper_mode import (
     GripperMode,
@@ -110,7 +114,7 @@ def __init__(self, config: ManipulatorEnvConfig, namespace: str = ""):
         self.observation_space = gym.spaces.Dict(
             {
                 **{
-                    f"observation.images.{camera.config.camera_name}": gym.spaces.Box(
+                    f"{ObservationKeys.IMAGE_OBS}.{camera.config.camera_name}": gym.spaces.Box(
                         low=np.zeros((*camera.config.resolution, 3), dtype=np.uint8),
                         high=255 * np.ones((*camera.config.resolution, 3), dtype=np.uint8),
                         dtype=np.uint8,
@@ -119,7 +123,7 @@ def __init__(self, config: ManipulatorEnvConfig, namespace: str = ""):
                     if camera.config.resolution is not None
                 },
                 # Combined state: cartesian pose (6D)
-                "observation.state.cartesian": gym.spaces.Box(
+                ObservationKeys.CARTESIAN_OBS: gym.spaces.Box(
                     low=np.concatenate(
                         [
                             -np.ones((6,), dtype=np.float32),  # cartesian pose
@@ -133,13 +137,13 @@ def __init__(self, config: ManipulatorEnvConfig, namespace: str = ""):
                     dtype=np.float32,
                 ),
                 # Gripper state
-                "observation.state.gripper": gym.spaces.Box(
+                ObservationKeys.GRIPPER_OBS: gym.spaces.Box(
                     low=np.array([0.0], dtype=np.float32),
                     high=np.array([1.0], dtype=np.float32),
                     dtype=np.float32,
                 ),
                 # Joint state
-                "observation.state.joint": gym.spaces.Box(
+                ObservationKeys.JOINT_OBS: gym.spaces.Box(
                     low=np.ones((self.config.robot_config.num_joints(),), dtype=np.float32)
                     * -np.pi,
                     high=np.ones((self.config.robot_config.num_joints(),), dtype=np.float32)
@@ -150,7 +154,7 @@ def __init__(self, config: ManipulatorEnvConfig, namespace: str = ""):
                 "task": gym.spaces.Text(max_length=256),
                 # Sensor data
                 **{
-                    f"observation.state.sensor_{sensor.config.name}": gym.spaces.Box(
+                    f"{ObservationKeys.SENSOR_OBS}_{sensor.config.name}": gym.spaces.Box(
                         low=-np.inf * np.ones(sensor.config.shape, dtype=np.float32),
                         high=np.inf * np.ones(sensor.config.shape, dtype=np.float32),
                         dtype=np.float32,
@@ -219,12 +223,7 @@ def _get_obs(self) -> dict:
         """Retrieve the current observation from the robot in LeRobot format.
 
         Returns:
-            dict: A dictionary containing the current sensor and state information in LeRobot format:
-                - 'observation.images.{camera_name}': RGB image from each configured camera.
-                - 'observation.state': Combined state vector (cartesian pose + gripper).
-                - 'observation.state.joint': Current joint configuration of the robot in radians.
-                - 'observation.state.sensor_{sensor_name}': Sensor values.
-                - 'task': Task description (empty string for now).
+            dict: A dictionary containing the current sensor and state information.
         """
         obs = {}
 
@@ -246,21 +245,23 @@
         )
 
         # Cartesian pose
-        obs["observation.state.cartesian"] = cartesian_pose.astype(np.float32)
+        obs[ObservationKeys.CARTESIAN_OBS] = cartesian_pose.astype(np.float32)
 
         # Gripper state
-        obs["observation.state.gripper"] = gripper_value.astype(np.float32)
+        obs[ObservationKeys.GRIPPER_OBS] = gripper_value.astype(np.float32)
 
         # Joint state
-        obs["observation.state.joint"] = self.robot.joint_values
+        obs[ObservationKeys.JOINT_OBS] = self.robot.joint_values
 
         # Camera images
         for camera in self.cameras:
-            obs[f"observation.images.{camera.config.camera_name}"] = camera.current_image
+            image_key = f"{ObservationKeys.IMAGE_OBS}.{camera.config.camera_name}"
+            obs[image_key] = camera.current_image
 
         # Sensor data
         for sensor in self.sensors:
-            obs[f"observation.state.sensor_{sensor.config.name}"] = sensor.value
+            sensor_key = f"{ObservationKeys.SENSOR_OBS}_{sensor.config.name}"
+            obs[sensor_key] = sensor.value
 
         return obs
 
@@ -374,10 +375,25 @@ def home(self, home_config: list[float] | None = None, blocking: bool = True):
         if not blocking:
             self.switch_to_default_controller()
 
+    def get_metadata(self) -> dict:
+        """Generate metadata for the environment.
+
+        Returns:
+            dict: Metadata dictionary.
+        """
+        from importlib.metadata import version
+
+        return {
+            "crisp_gym_version": version("crisp_gym"),
+            "crisp_py_version": version("crisp_python"),
+            "control_type": self.ctrl_type.name,
+            "env_config": self.config.get_metadata(),
+        }
+
     def move_to(
         self,
         position: List | NDArray | None = None,
-        pose: List | NDArray | None = None,
+        pose: Pose | None = None,
         speed: float = 0.05,
     ):
         """Move the robot to a specified position or pose.
@@ -432,7 +448,7 @@ def __init__(self, config: ManipulatorEnvConfig, namespace: str = ""):
         self.observation_space: gym.spaces.Dict = gym.spaces.Dict(
             {
                 **self.observation_space.spaces,
-                "observation.state.target": gym.spaces.Box(
+                ObservationKeys.TARGET_OBS: gym.spaces.Box(
                     low=np.ones((6,), dtype=np.float32) * -np.pi,
                     high=np.ones((6,), dtype=np.float32) * np.pi,
                     dtype=np.float32,
@@ -537,7 +553,7 @@ def __init__(self, config: ManipulatorEnvConfig, namespace: str = ""):
         self.observation_space: gym.spaces.Dict = gym.spaces.Dict(
             {
                 **self.observation_space.spaces,
-                "observation.state.target": gym.spaces.Box(
+                ObservationKeys.TARGET_OBS: gym.spaces.Box(
                     low=np.ones((self.num_joints,), dtype=np.float32) * -np.pi,
                     high=np.ones((self.num_joints,), dtype=np.float32) * np.pi,
                     dtype=np.float32,
diff --git a/crisp_gym/manipulator_env_config.py b/crisp_gym/manipulator_env_config.py
index fca824b..24d76e8 100644
--- a/crisp_gym/manipulator_env_config.py
+++ b/crisp_gym/manipulator_env_config.py
@@ -16,6 +16,29 @@
 from crisp_gym.util.gripper_mode import GripperMode
 
 
+class ObservationKeys:
+    """Standardized keys for observations in manipulator environments."""
+
+    STATE_OBS = "observation.state"
+
+    GRIPPER_OBS = STATE_OBS + ".gripper"
+    JOINT_OBS = STATE_OBS + ".joints"
+    CARTESIAN_OBS = STATE_OBS + ".cartesian"
+    TARGET_OBS = STATE_OBS + ".target"
+    SENSOR_OBS = STATE_OBS + ".sensors"
+
+    IMAGE_OBS = "observation.images"
+
+
+ALLOWED_STATE_OBS_KEYS = {
+    ObservationKeys.GRIPPER_OBS,
+    ObservationKeys.JOINT_OBS,
+    ObservationKeys.CARTESIAN_OBS,
+    ObservationKeys.TARGET_OBS,
+    ObservationKeys.SENSOR_OBS,
+}
+
+
 @dataclass(kw_only=True)
 class ManipulatorEnvConfig(ABC):
     """Manipulator Gym Environment Configuration.
@@ -68,6 +91,23 @@ def __post_init__(self):
         if isinstance(self.gripper_mode, str):
             self.gripper_mode = GripperMode(self.gripper_mode)
 
+    def get_metadata(self) -> dict:
+        """Get metadata about the environment configuration.
+
+        Returns:
+            dict: Metadata dictionary containing control frequency, robot type, gripper type, and camera names.
+        """
+        return {
+            "robot_config": self.robot_config.__dict__,
+            "gripper_config": self.gripper_config.__dict__ if self.gripper_config else "None",
+            "camera_config": [camera.__dict__ for camera in self.camera_configs],
+            "sensor_config": [sensor.__dict__ for sensor in self.sensor_configs],
+            "gripper_mode": str(self.gripper_mode),
+            "gripper_threshold": self.gripper_threshold,
+            "cartesian_control_param_config": str(self.cartesian_control_param_config),
+            "joint_control_param_config": str(self.joint_control_param_config),
+        }
+
     @classmethod
     def from_yaml(cls, yaml_path: Path, **overrides) -> "ManipulatorEnvConfig":  # noqa: ANN003
         """Load config from YAML file with optional overrides.
diff --git a/crisp_gym/record/record_functions.py b/crisp_gym/record/record_functions.py
index 782148a..4fb64a4 100644
--- a/crisp_gym/record/record_functions.py
+++ b/crisp_gym/record/record_functions.py
@@ -14,7 +14,8 @@
 from lerobot.policies.factory import get_policy_class
 
 from crisp_gym.util.control_type import ControlType
-from crisp_gym.util.lerobot_features import numpy_obs_to_torch
+from crisp_gym.util.lerobot_features import concatenate_state_features, numpy_obs_to_torch
+from crisp_gym.util.setup_logger import setup_logging
 
 if TYPE_CHECKING:
     from multiprocessing.connection import Connection
@@ -24,6 +25,9 @@
     from crisp_gym.teleop.teleop_sensor_stream import TeleopStreamedPose
 
 
+logger = logging.getLogger(__name__)
+
+
 def make_teleop_streamer_fn(env: ManipulatorCartesianEnv, leader: TeleopStreamedPose) -> Callable:
     """Create a teleoperation function for the leader robot using streamed pose data."""
     prev_pose = leader.last_pose
@@ -166,6 +170,9 @@ def inference_worker(
     policy.to(device).eval()
 
     warmup_obs_raw = env.observation_space.sample()
+    logging.info(f"[Inference] Warm-up observation keys: {list(warmup_obs_raw.keys())}")
+    warmup_obs_raw["observation.state"] = concatenate_state_features(warmup_obs_raw)
+    logging.info(f"[Inference] Warm-up observation keys: {list(warmup_obs_raw.keys())}")
     warmup_obs = numpy_obs_to_torch(warmup_obs_raw)
 
     with torch.inference_mode():
@@ -221,15 +228,19 @@
 
         """
         obs_raw = env.get_obs()
+        from crisp_gym.util.lerobot_features import concatenate_state_features
+
+        obs_raw["observation.state"] = concatenate_state_features(obs_raw)
+
        # Send observation to inference worker and receive action
         parent_conn.send(obs_raw)
         action = parent_conn.recv().squeeze(0).to("cpu").numpy()
-        logging.debug(f"Action: {action}")
+        logger.debug(f"Action: {action}")
 
         try:
             env.step(action, block=False)
         except Exception as e:
-            logging.exception(f"Error during environment step: {e}")
+            logger.exception(f"Error during environment step: {e}")
 
         return obs_raw, action
diff --git a/crisp_gym/record/recording_manager.py b/crisp_gym/record/recording_manager.py
index 34c5964..e2e15e1 100644
--- a/crisp_gym/record/recording_manager.py
+++ b/crisp_gym/record/recording_manager.py
@@ -76,6 +76,11 @@ def __init__(
         )
         self.writer.start()
 
+    @property
+    def dataset_directory(self) -> Path:
+        """Return the path to the dataset directory."""
+        return Path(HF_LEROBOT_HOME / self.config.repo_id)
+
     @property
     def num_episodes(self) -> int:
         """Return the number of episodes to record."""
diff --git a/crisp_gym/util/lerobot_features.py b/crisp_gym/util/lerobot_features.py
index 1e90cf6..bffdd61 100644
--- a/crisp_gym/util/lerobot_features.py
+++ b/crisp_gym/util/lerobot_features.py
@@ -67,6 +67,7 @@ def get_features(
     state_feature_length = 0
     state_feature_names = []
 
+    # TODO: unify with crisp_gym observation keys
     for feature_key in env.observation_space.keys():
         if ignore_keys and feature_key in ignore_keys:
             continue
@@ -104,7 +105,7 @@
         elif feature_key.startswith("task"):
             continue  # Task features are handled separately
 
-        elif feature_key.startswith("observation.images."):
+        elif feature_key.startswith("observation.images"):
            if not use_video:
                features[feature_key] = {
                    "dtype": "image",
@@ -169,7 +170,9 @@ def construct_state_feature(length: int, names: list[str]) -> Dict[str, Any]:
     }
 
 
-def concatenate_state_features(obs: Dict[str, Any], features: Dict[str, Dict]) -> np.ndarray:
+def concatenate_state_features(
+    obs: Dict[str, Any], features: Dict[str, Dict] | None = None
+) -> np.ndarray:
     """Concatenate individual state features into a single state vector.
 
     This function takes the individual state components from the observation dictionary
@@ -184,29 +187,27 @@
     """
     state_components = []
 
-    for feature_name in features:
+    for feature_name in obs:
         if not feature_name.startswith("observation.state"):
             continue
         if feature_name == "observation.state":
             continue  # Skip the combined state feature
 
-        if feature_name in obs:
-            value = obs[feature_name]
-            if isinstance(value, np.ndarray):
-                state_components.append(value.astype(np.float32))
-            else:
-                state_components.append(np.array(value, dtype=np.float32))
+        value = obs[feature_name]
+        if isinstance(value, np.ndarray):
+            state_components.append(value.astype(np.float32))
         else:
-            raise ValueError(f"Missing required state component: {feature_name}")
+            state_components.append(np.array(value, dtype=np.float32))
 
     concatenated_state = np.concatenate(state_components, axis=0)
 
-    expected_length = features["observation.state"]["shape"][0]
+    if features:
+        expected_length = features["observation.state"]["shape"][0]
 
-    if concatenated_state.shape[0] != expected_length:
-        raise ValueError(
-            f"Concatenated state length {concatenated_state.shape[0]} does not match expected length {expected_length}."
-        )
+        if concatenated_state.shape[0] != expected_length:
+            raise ValueError(
+                f"Concatenated state length {concatenated_state.shape[0]} does not match expected length {expected_length}."
+            )
 
     return concatenated_state
 
diff --git a/scripts/deploy_policy.py b/scripts/deploy_policy.py
index ec3d6b2..d8cd526 100644
--- a/scripts/deploy_policy.py
+++ b/scripts/deploy_policy.py
@@ -164,6 +164,8 @@
 try:
     ctrl_type = "cartesian" if not args.joint_control else "joint"
     env = make_env(args.env_config, control_type=ctrl_type, namespace=args.env_namespace)
+    env.config.robot_config.home_config = home_close_to_table
+    env.config.robot_config.time_to_home = 2.0
 
     # %% Prepare the dataset
     features = get_features(env)
@@ -212,6 +214,7 @@ def on_end():
         """Hook function to be called when stopping the recording."""
         env.robot.reset_targets()
         env.robot.home(blocking=False)
+        env.gripper.open()
 
     with evaluator.start_eval(overwrite=True, activate=args.evaluate):
         with recording_manager:
diff --git a/scripts/record_lerobot_format_leader_follower.py b/scripts/record_lerobot_format_leader_follower.py
index 2563acc..a815153 100644
--- a/scripts/record_lerobot_format_leader_follower.py
+++ b/scripts/record_lerobot_format_leader_follower.py
@@ -1,13 +1,14 @@
 """Script showcasing how to record data in Lerobot Format."""
 
 import argparse
+import json
 import logging
 
 import numpy as np
-import rclpy  # noqa: F401
+import rclpy
 
 import crisp_gym  # noqa: F401
-from crisp_gym.config.home import home_close_to_table
+from crisp_gym.config.home import HomeConfig
 from crisp_gym.config.path import CRISP_CONFIG_PATH
 from crisp_gym.manipulator_env import ManipulatorCartesianEnv, make_env
 from crisp_gym.manipulator_env_config import list_env_configs
@@ -112,6 +113,12 @@
     action="store_true",
     help="Whether to use streamed teleop (e.g., from a phone or VR device) for the leader robot.",
 )
+parser.add_argument(
+    "--home-config-noise",
+    type=float,
+    default=0.0,
+    help="Noise to add to the home configuration when homing the robots to randomize the position a bit.",
+)
 
 args = parser.parse_args()
 
@@ -168,6 +175,8 @@
     control_type=ctrl_type,
     namespace=args.follower_namespace,
 )
+env.config.robot_config.home_config = HomeConfig.CLOSE_TO_TABLE.value
+env.config.robot_config.time_to_home = 2.0
 
 leader: TeleopRobot | TeleopStreamedPose | None = None
 if args.use_streamed_teleop:
@@ -176,7 +185,7 @@
 else:
     leader = make_leader(args.leader_config, namespace=args.leader_namespace)
     leader.wait_until_ready()
-    leader.config.leader.home_config = home_close_to_table
+    leader.config.leader.home_config = HomeConfig.CLOSE_TO_TABLE.value
     leader.config.leader.time_to_home = 2.0
     logger.info("Using teleop robot for the leader robot. Leader is ready.")
 
@@ -201,15 +210,25 @@
     push_to_hub=args.push_to_hub,
 )
 recording_manager.wait_until_ready()
+logger.info("Recording manager is ready.")
+
+env_metadata = env.get_metadata()
+
+with open(recording_manager.dataset_directory / "meta" / "crisp_meta.json", "w") as f:
+    json.dump(env_metadata, f, indent=4)
+
+logger.info(
+    f"Environment metadata saved to {recording_manager.dataset_directory / 'meta' / 'crisp_meta.json'}"
+)
 
 logger.info("Homing both robots before starting with recording.")
 
 # Prepare environment and leader
 if isinstance(leader, TeleopRobot):
     leader.prepare_for_teleop()
-env.robot.config.home_config = home_close_to_table
-env.robot.config.time_to_home = 2.0
-env.home()
+
+env.wait_until_ready()
+env.home(home_config=HomeConfig.CLOSE_TO_TABLE.randomize(noise=args.home_config_noise))
 env.reset()
 
 tasks = list(args.tasks)
@@ -236,10 +255,12 @@ def on_start():
 def on_end():
     """Hook function to be called when stopping the recording."""
     env.robot.reset_targets()
-    env.robot.home(blocking=False)
+    random_home = HomeConfig.CLOSE_TO_TABLE.randomize(noise=args.home_config_noise)
+    env.robot.home(blocking=False, home_config=random_home)
     if isinstance(leader, TeleopRobot):
         leader.robot.reset_targets()
-        leader.robot.home(blocking=False)
+        leader.robot.home(blocking=False, home_config=random_home)
+    env.gripper.open()
 
 with recording_manager:
     while not recording_manager.done():