From 49716c468dd19927b6db25d4f283006e701fe49f Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Wed, 3 Apr 2024 15:29:09 +0200 Subject: [PATCH 1/9] fix robomimic_graph_dataset multirobot variant --- imitation/config/task/transport_graph.yaml | 22 ++++-- imitation/dataset/robomimic_graph_dataset.py | 75 +++++++++++--------- 2 files changed, 58 insertions(+), 39 deletions(-) diff --git a/imitation/config/task/transport_graph.yaml b/imitation/config/task/transport_graph.yaml index c9b3606..611a82f 100644 --- a/imitation/config/task/transport_graph.yaml +++ b/imitation/config/task/transport_graph.yaml @@ -3,7 +3,9 @@ task_name: &task_name transport dataset_type: &dataset_type ph dataset_path: &dataset_path /home/caio/workspace/GraphDiffusionImitate/data/${task.task_name}/${task.dataset_type}/low_dim_v141.hdf5 -max_steps: 1000 +max_steps: ${max_steps} + +control_mode: "JOINT_POSITION" obs_dim: 101 action_dim: 18 @@ -56,11 +58,20 @@ env_runner: action_horizon: ${action_horizon} obs_horizon: ${obs_horizon} render: ${render} + output_video: ${output_video} env: - _target_: imitation.env.robomimic_lowdim_wrapper.RobomimicLowdimWrapper + _target_: imitation.env.robomimic_graph_wrapper.RobomimicGraphWrapper + object_state_sizes: *object_state_sizes + object_state_keys: *object_state_keys max_steps: ${task.max_steps} task: "TwoArmTransport" + has_renderer: ${render} robots: *robots + output_video: ${output_video} + control_mode: ${task.control_mode} + controller_config: + interpolation: "linear" + ramp_ratio: 0.2 dataset: @@ -71,5 +82,8 @@ dataset: obs_horizon: ${obs_horizon} object_state_sizes: *object_state_sizes object_state_keys: *object_state_keys - mode: "task-joint-space" - robots: *robots \ No newline at end of file + robots: *robots + control_mode: ${task.control_mode} + base_link_shift: + - [0.0, 0.0, 0.0] + - [0.0, 0.0, 0.0] \ No newline at end of file diff --git a/imitation/dataset/robomimic_graph_dataset.py b/imitation/dataset/robomimic_graph_dataset.py index 7d981eb..285e87c 100644 --- a/imitation/dataset/robomimic_graph_dataset.py +++ b/imitation/dataset/robomimic_graph_dataset.py @@ -138,12 +138,12 @@ def _get_node_feats_horizon(self, data, idx, horizon): Calculate node features for self.obs_horizon time steps ''' node_feats = [] - episode_length = data["object"].shape[0] # calculate node features for timesteps idx to idx + horizon t_vals = list(range(idx, idx + horizon)) node_feats = self._get_node_feats(data, t_vals) return node_feats + @lru_cache(maxsize=None) def _get_edge_attrs(self, edge_index): ''' Attribute edge types to edges @@ -161,7 +161,7 @@ def _get_edge_attrs(self, edge_index): edge_attrs.append(self.OBJECT_ROBOT_EDGE) return torch.tensor(edge_attrs, dtype=torch.long) - + @lru_cache(maxsize=None) def _get_edge_index(self, num_nodes): ''' Returns edge index for graph. 
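The connectivity that _get_edge_index builds (described in the docstring above and reworked in the hunk that follows) is a chain along the robot's joints plus edges attaching every object node to the robot. A minimal, self-contained sketch of that pattern, assuming the single-robot layout used elsewhere in this series (9 Panda nodes, eef_idx = 8); the helper name is illustrative only, not code from the repository:

    import torch

    def chain_plus_object_edges(num_nodes: int, eef_idx: int = 8) -> torch.Tensor:
        # chain along the robot links: [0,1], [1,2], ..., [7,8]
        edges = [[idx, idx + 1] for idx in range(eef_idx)]
        # attach every remaining (object) node to every robot node,
        # mirroring the list-comprehension form used later in this series
        edges += [[node_idx, idx]
                  for idx in range(eef_idx + 1, num_nodes)
                  for node_idx in range(eef_idx + 1)]
        return torch.tensor(edges, dtype=torch.long).t().contiguous()

    # one Panda arm (9 nodes) plus a single object node:
    print(chain_plus_object_edges(10).shape)  # torch.Size([2, 17])
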
@@ -173,9 +173,9 @@ def _get_edge_index(self, num_nodes): for idx in range(eef_idx): edge_index.append([idx, idx+1]) - # Connectivity of all other nodes to the last node of robot + # Connectivity of all other nodes to all robot nodes for idx in range(eef_idx + 1, num_nodes): - edge_index.append([idx, eef_idx]) + edge_index.append(torch.tensor([node_idx, idx]) for node_idx in range(eef_idx + 1)) edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous() return edge_index @@ -289,10 +289,11 @@ def __init__(self, object_state_sizes, object_state_keys, robots, + control_mode, pred_horizon=1, obs_horizon=1, - node_feature_dim = 8, - mode="joint-space"): # TODO update according to RobomimicGraphDataset + node_feature_dim = 2, + base_link_shift=[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]): self.num_robots = len(robots) self.eef_idx = [0, 7, 13] super().__init__(dataset_path=dataset_path, @@ -301,48 +302,53 @@ def __init__(self, object_state_keys=object_state_keys, pred_horizon=pred_horizon, obs_horizon=obs_horizon, - mode=mode, + control_mode=control_mode, node_feature_dim = node_feature_dim, ) + def _get_node_pos(self, data, t): + for i in range(self.num_robots): + node_pos = calculate_panda_joints_positions([*data[f"robot{i}_joint_pos"][t], *data[f"robot{i}_gripper_qpos"][t]]) + node_pos[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i]) + # TODO find out how the robots are rotated in the transport environment + # use rotation transformer to convert quaternion to 6d rotation + node_pos = torch.cat([node_pos[:,:3], self.rotation_transformer.forward(node_pos[:,3:])], dim=1) + obj_pos_tensor = self._get_object_pos(data, t) + node_pos = torch.cat((node_pos, obj_pos_tensor), dim=0) + return node_pos - def _get_node_feats(self, data, t): + + def _get_node_feats(self, data, t_vals): ''' Here, robot0_eef_pos, robot1_eef_pos, ... are used as node features. 
''' + T = len(t_vals) node_feats = [] - if self.mode == "end-effector": + if self.control_mode == "OSC_POSE": for i in range(self.num_robots): - node_feats.append(torch.cat([torch.tensor(data[f"robot{i}_eef_pos"][t - 1:t][0]), torch.tensor(data[f"robot{i}_eef_quat"][t - 1:t][0])], dim=0)) - node_feats = torch.stack(node_feats) - elif self.mode == "task-space": - for j in range(self.num_robots): - node_feats.append(calculate_panda_joints_positions([*data[f"robot{j}_joint_pos"][t], *data[f"robot{j}_gripper_qpos"][t]])) - node_feats = torch.cat(node_feats) - elif self.mode == "joint-space": + node_feats.append(torch.cat([torch.tensor(data["robot0_eef_pos"][t_vals]), torch.tensor(data["robot0_eef_quat"][t_vals])], dim=0)) + elif self.control_mode == "JOINT_POSITION": for i in range(self.num_robots): node_feats.append(torch.cat([ - torch.tensor(data[f"robot{i}_joint_pos"][t - 1:t][0]).reshape(1,-1), - torch.zeros((6,7))])) # complete with zeros to match task-space dimensionality - node_feats = torch.cat(node_feats) - elif self.mode == "task-joint-space": + torch.tensor(data[f"robot{i}_joint_pos"][t_vals]), + torch.tensor(data[f"robot{i}_gripper_qpos"][t_vals])], dim=1).T.unsqueeze(2)) + elif self.control_mode == "JOINT_VELOCITY": for i in range(self.num_robots): - node_feats.append(torch.cat([calculate_panda_joints_positions([*data[f"robot{i}_joint_pos"][t], *data[f"robot{i}_gripper_qpos"][t]]), - torch.tensor(data[f"robot{i}_joint_pos"][t - 1:t]).reshape(-1,1)], dim=1)) - node_feats = torch.cat(node_feats, dim=0) - - else: - raise NotImplementedError + node_feats.append(torch.cat([ + torch.tensor(data[f"robot{i}_joint_vel"][t_vals]), + torch.tensor(data[f"robot{i}_gripper_qvel"][t_vals])], dim=1).T.unsqueeze(2)) + node_feats = torch.cat(node_feats, dim=0) # [num_robots*num_joints, T, 1] # add dimension for NODE_TYPE, which is 0 for robot and 1 for objects - node_feats = torch.cat((node_feats, self.ROBOT_NODE_TYPE*torch.ones((node_feats.shape[0],1))), dim=1) + node_feats = torch.cat((node_feats, self.ROBOT_NODE_TYPE*torch.ones((node_feats.shape[0],node_feats.shape[1],1))), dim=2) - obj_state_tensor = self._get_object_feats(self.num_objects, self.node_feature_dim, self.OBJECT_NODE_TYPE) + obj_state_tensor = self._get_object_feats(self.num_objects, self.node_feature_dim, self.OBJECT_NODE_TYPE, T) node_feats = torch.cat((node_feats, obj_state_tensor), dim=0) return node_feats + @lru_cache(maxsize=None) def _get_edge_index(self, num_nodes): ''' Returns edge index for graph. 
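The multirobot _get_node_feats rework in the hunk above stacks each robot's joint readings over the horizon and then appends the node-type flag as an extra feature channel, giving the [num_robots*num_joints, T, 1] layout noted in the comment. A minimal shape-bookkeeping sketch of the JOINT_POSITION branch, with random stand-ins for the hdf5 arrays (two robots, 7 joints plus 2 gripper values, horizon T = 4):

    import torch

    T, num_robots = 4, 2
    node_feats = []
    for i in range(num_robots):
        joint_pos = torch.randn(T, 7)      # stand-in for data[f"robot{i}_joint_pos"][t_vals]
        gripper_qpos = torch.randn(T, 2)   # stand-in for data[f"robot{i}_gripper_qpos"][t_vals]
        # concatenate to (T, 9), transpose to (9, T), add a trailing feature dimension
        node_feats.append(torch.cat([joint_pos, gripper_qpos], dim=1).T.unsqueeze(2))
    node_feats = torch.cat(node_feats, dim=0)                         # (18, 4, 1)
    # append the robot node-type flag (ROBOT_NODE_TYPE = 1 in this codebase) as a second channel
    flag = torch.ones((node_feats.shape[0], node_feats.shape[1], 1))
    node_feats = torch.cat((node_feats, flag), dim=2)
    print(node_feats.shape)                                           # torch.Size([18, 4, 2])
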
@@ -350,17 +356,16 @@ def _get_edge_index(self, num_nodes): - all object nodes are connected to the last robot node (end-effector) ''' assert len(self.eef_idx) == self.num_robots + 1 - edge_index = [] + edge_index = [[self.eef_idx[0], self.eef_idx[1] + 1]] # robot0 base link to robot1 base link - for id_robot in range(1, len(self.eef_idx)): - for idx in range(self.eef_idx[id_robot-1], self.eef_idx[id_robot]): - edge_index.append([idx, idx+1]) - # Connectivity of all other nodes to the last node of all robots - for idx in range(self.eef_idx[self.num_robots] + 1, num_nodes): - edge_index.append([self.eef_idx[id_robot], idx]) + edge_index += [[idx, idx+1] for id_robot in range(1, len(self.eef_idx)-1) for idx in range(self.eef_idx[id_robot-1], self.eef_idx[id_robot])] + # Connectivity of all other nodes to all robot nodes + edge_index += [[node_idx, idx] for idx in range(self.eef_idx[self.num_robots] + 1, num_nodes) for node_idx in range(self.eef_idx[self.num_robots] + 1)] + # edge_index.append(torch.tensor([node_idx, idx]) for node_idx in range(self.eef_idx[self.num_robots] + 1)) edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous() return edge_index + @lru_cache(maxsize=None) def _get_edge_attrs(self, edge_index): ''' Attribute edge types to edges From b998a0fb9933af1f5d1366b76906e794314d7420 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Wed, 3 Apr 2024 18:33:51 +0200 Subject: [PATCH 2/9] add base link shift to square environment --- imitation/config/task/square_graph.yaml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/imitation/config/task/square_graph.yaml b/imitation/config/task/square_graph.yaml index 449d171..174ac97 100644 --- a/imitation/config/task/square_graph.yaml +++ b/imitation/config/task/square_graph.yaml @@ -44,6 +44,7 @@ env_runner: controller_config: interpolation: "linear" ramp_ratio: 0.2 + base_link_shift: [-0.56, 0, 0.912] dataset: _target_: imitation.dataset.robomimic_graph_dataset.RobomimicGraphDataset @@ -53,4 +54,5 @@ dataset: obs_horizon: ${obs_horizon} object_state_sizes: *object_state_sizes object_state_keys: *object_state_keys - control_mode: ${task.control_mode} \ No newline at end of file + control_mode: ${task.control_mode} + base_link_shift: [-0.56, 0, 0.912] \ No newline at end of file From e5013ac9cdeb6c88007c1c764064e447751f8e71 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Wed, 3 Apr 2024 18:34:09 +0200 Subject: [PATCH 3/9] add correct base link shifts to transport, and base link rotations --- imitation/config/task/transport_graph.yaml | 7 +++++-- imitation/dataset/robomimic_graph_dataset.py | 18 ++++++++++++++---- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/imitation/config/task/transport_graph.yaml b/imitation/config/task/transport_graph.yaml index 611a82f..69e9ea1 100644 --- a/imitation/config/task/transport_graph.yaml +++ b/imitation/config/task/transport_graph.yaml @@ -85,5 +85,8 @@ dataset: robots: *robots control_mode: ${task.control_mode} base_link_shift: - - [0.0, 0.0, 0.0] - - [0.0, 0.0, 0.0] \ No newline at end of file + - [0.0, -0.81, 0.912] + - [0.0, 0.81, 0.912] + base_link_rotation: + - [0.707107, 0, 0, 0.707107] + - [0.707107, 0, 0, -0.707107] \ No newline at end of file diff --git a/imitation/dataset/robomimic_graph_dataset.py b/imitation/dataset/robomimic_graph_dataset.py index 285e87c..0adf091 100644 --- a/imitation/dataset/robomimic_graph_dataset.py +++ b/imitation/dataset/robomimic_graph_dataset.py @@ -9,6 +9,7 @@ from tqdm import tqdm from typing import List, Dict 
from functools import lru_cache +from scipy.spatial.transform import Rotation as R from diffusion_policy.model.common.rotation_transformer import RotationTransformer @@ -293,9 +294,12 @@ def __init__(self, pred_horizon=1, obs_horizon=1, node_feature_dim = 2, - base_link_shift=[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]): + base_link_shift=[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + base_link_rotation=[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]): self.num_robots = len(robots) self.eef_idx = [0, 7, 13] + self.BASE_LINK_ROTATION = base_link_rotation + super().__init__(dataset_path=dataset_path, action_keys=action_keys, object_state_sizes=object_state_sizes, @@ -304,13 +308,19 @@ def __init__(self, obs_horizon=obs_horizon, control_mode=control_mode, node_feature_dim = node_feature_dim, - ) - + base_link_shift=base_link_shift) + def _get_node_pos(self, data, t): for i in range(self.num_robots): node_pos = calculate_panda_joints_positions([*data[f"robot{i}_joint_pos"][t], *data[f"robot{i}_gripper_qpos"][t]]) - node_pos[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i]) + # rotate robot nodes + rotation_matrix = R.from_quat(self.BASE_LINK_ROTATION[i]) + node_pos[:,:3] = torch.matmul(node_pos[:,:3], torch.tensor(rotation_matrix.as_matrix())) + node_pos[:,3:] = torch.tensor((R.from_quat(node_pos[:,3:].detach().numpy()) * rotation_matrix).as_quat()) + # add base link shift + node_pos[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i]) + # TODO find out how the robots are rotated in the transport environment # use rotation transformer to convert quaternion to 6d rotation node_pos = torch.cat([node_pos[:,:3], self.rotation_transformer.forward(node_pos[:,3:])], dim=1) From 798708b043e4b4c17fa5ea8b679a3e6869d2b767 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Wed, 3 Apr 2024 19:04:06 +0200 Subject: [PATCH 4/9] add graph pos to both robots --- imitation/dataset/robomimic_graph_dataset.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/imitation/dataset/robomimic_graph_dataset.py b/imitation/dataset/robomimic_graph_dataset.py index 0adf091..3837eac 100644 --- a/imitation/dataset/robomimic_graph_dataset.py +++ b/imitation/dataset/robomimic_graph_dataset.py @@ -312,15 +312,17 @@ def __init__(self, def _get_node_pos(self, data, t): + node_pos = [] for i in range(self.num_robots): - node_pos = calculate_panda_joints_positions([*data[f"robot{i}_joint_pos"][t], *data[f"robot{i}_gripper_qpos"][t]]) + node_pos_robot = calculate_panda_joints_positions([*data[f"robot{i}_joint_pos"][t], *data[f"robot{i}_gripper_qpos"][t]]) # rotate robot nodes rotation_matrix = R.from_quat(self.BASE_LINK_ROTATION[i]) - node_pos[:,:3] = torch.matmul(node_pos[:,:3], torch.tensor(rotation_matrix.as_matrix())) - node_pos[:,3:] = torch.tensor((R.from_quat(node_pos[:,3:].detach().numpy()) * rotation_matrix).as_quat()) + node_pos_robot[:,:3] = torch.matmul(node_pos_robot[:,:3], torch.tensor(rotation_matrix.as_matrix())) + node_pos_robot[:,3:] = torch.tensor((R.from_quat(node_pos_robot[:,3:].detach().numpy()) * rotation_matrix).as_quat()) # add base link shift - node_pos[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i]) - + node_pos_robot[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i]) + node_pos.append(node_pos_robot) + node_pos = torch.cat(node_pos, dim=0) # TODO find out how the robots are rotated in the transport environment # use rotation transformer to convert quaternion to 6d rotation node_pos = torch.cat([node_pos[:,:3], self.rotation_transformer.forward(node_pos[:,3:])], dim=1) From e419f9e8302da967f5bfc027d72bad7238f156bd Mon Sep 17 
00:00:00 2001 From: Caio Freitas Date: Wed, 3 Apr 2024 22:29:28 +0200 Subject: [PATCH 5/9] match wrapper with multirobot dataset --- imitation/config/task/transport_graph.yaml | 19 ++++-- imitation/dataset/robomimic_graph_dataset.py | 3 +- imitation/env/robomimic_graph_wrapper.py | 69 ++++++++++++-------- 3 files changed, 56 insertions(+), 35 deletions(-) diff --git a/imitation/config/task/transport_graph.yaml b/imitation/config/task/transport_graph.yaml index 69e9ea1..99d4499 100644 --- a/imitation/config/task/transport_graph.yaml +++ b/imitation/config/task/transport_graph.yaml @@ -52,6 +52,13 @@ object_state_keys: &object_state_keys - trash_pos - trash_quat +base_link_shift: + - [0.0, -0.81, 0.912] + - [0.0, 0.81, 0.912] +base_link_rotation: + - [0.707107, 0, 0, 0.707107] + - [0.707107, 0, 0, -0.707107] + env_runner: _target_: imitation.env_runner.robomimic_lowdim_runner.RobomimicEnvRunner output_dir: ${output_dir} @@ -72,7 +79,9 @@ env_runner: controller_config: interpolation: "linear" ramp_ratio: 0.2 - + base_link_shift: ${task.base_link_shift} + base_link_rotation: ${task.base_link_rotation} + dataset: _target_: imitation.dataset.robomimic_graph_dataset.MultiRobotGraphDataset @@ -84,9 +93,5 @@ dataset: object_state_keys: *object_state_keys robots: *robots control_mode: ${task.control_mode} - base_link_shift: - - [0.0, -0.81, 0.912] - - [0.0, 0.81, 0.912] - base_link_rotation: - - [0.707107, 0, 0, 0.707107] - - [0.707107, 0, 0, -0.707107] \ No newline at end of file + base_link_shift: ${task.base_link_shift} + base_link_rotation: ${task.base_link_rotation} \ No newline at end of file diff --git a/imitation/dataset/robomimic_graph_dataset.py b/imitation/dataset/robomimic_graph_dataset.py index 3837eac..0a34365 100644 --- a/imitation/dataset/robomimic_graph_dataset.py +++ b/imitation/dataset/robomimic_graph_dataset.py @@ -90,7 +90,7 @@ def _get_object_pos(self, data, t): i = 0 for object_state in object_state_items: if "quat" in object_state: - assert self.object_state_sizes[object_state] == 4 + assert self.object_state_sizes[object_state] == 4, "Quaternion must have size 4" rot = self.rotation_transformer.forward(torch.tensor(data["object"][t][i:i + self.object_state_sizes[object_state]])) obj_state_tensor[object,i:i + 6] = rot else: @@ -323,7 +323,6 @@ def _get_node_pos(self, data, t): node_pos_robot[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i]) node_pos.append(node_pos_robot) node_pos = torch.cat(node_pos, dim=0) - # TODO find out how the robots are rotated in the transport environment # use rotation transformer to convert quaternion to 6d rotation node_pos = torch.cat([node_pos[:,:3], self.rotation_transformer.forward(node_pos[:,3:])], dim=1) obj_pos_tensor = self._get_object_pos(data, t) diff --git a/imitation/env/robomimic_graph_wrapper.py b/imitation/env/robomimic_graph_wrapper.py index 4dc1a17..28132f1 100644 --- a/imitation/env/robomimic_graph_wrapper.py +++ b/imitation/env/robomimic_graph_wrapper.py @@ -5,7 +5,7 @@ import robosuite as suite from robosuite.controllers import load_controller_config from robosuite.wrappers.gym_wrapper import GymWrapper - +from scipy.spatial.transform import Rotation as R import torch import torch_geometric @@ -52,7 +52,8 @@ def __init__(self, output_video=False, control_mode="JOINT_VELOCITY", controller_config=None, - base_link_shift=[0.0, 0.0, 0.0] + base_link_shift=[0.0, 0.0, 0.0], + base_link_rotation=[[0.0, 0.0, 0.0, 1.0]] ): ''' Environment wrapper for Robomimic's GraphDiffusionImitate dataset in the same Graph representation as @@ -92,6 
+93,7 @@ def __init__(self, self.NUM_GRAPH_NODES = 9 + self.num_objects # TODO add multi-robot support self.BASE_LINK_SHIFT = base_link_shift + self.BASE_LINK_ROTATION = base_link_rotation self.ROBOT_NODE_TYPE = 1 self.OBJECT_NODE_TYPE = -1 @@ -132,24 +134,38 @@ def _get_object_feats(self, data): return obj_state_tensor def _get_object_pos(self, data): - obj_state_tensor = torch.zeros((self.num_objects, 7)) # 3 for position, 4 for quaternion + obj_state_tensor = torch.zeros((self.num_objects, 9)) # 3 for position, 6 for rotation for object, object_state_items in enumerate(self.object_state_keys.values()): i = 0 for object_state in object_state_items: - obj_state_tensor[object,i:i + self.object_state_sizes[object_state]] = torch.from_numpy(data["object"][i:i + self.object_state_sizes[object_state]]) + if "quat" in object_state: + assert self.object_state_sizes[object_state] == 4, "Quaternion must have size 4" + rot = self.rotation_transformer.forward(torch.tensor(data["object"][i:i + self.object_state_sizes[object_state]])) + obj_state_tensor[object,i:i + 6] = rot + else: + obj_state_tensor[object,i:i + self.object_state_sizes[object_state]] = torch.from_numpy(data["object"][i:i + self.object_state_sizes[object_state]]) + i += self.object_state_sizes[object_state] return obj_state_tensor def _get_node_pos(self, data): - node_pos = calculate_panda_joints_positions([*data["robot0_joint_pos"], *data["robot0_gripper_qpos"]]) - node_pos[:,:3] += torch.tensor(self.BASE_LINK_SHIFT) - obj_pos_tensor = self._get_object_pos(data) - node_pos = torch.cat((node_pos, obj_pos_tensor), dim=0) + node_pos = [] + for i in range(len(self.robots)): + node_pos_robot = calculate_panda_joints_positions([*data[f"robot{i}_joint_pos"], *data[f"robot{i}_gripper_qpos"]]) + rotation_matrix = R.from_quat(self.BASE_LINK_ROTATION[i]) + node_pos_robot[:,:3] = torch.matmul(node_pos_robot[:,:3], torch.tensor(rotation_matrix.as_matrix())) + node_pos_robot[:,3:] = torch.tensor((R.from_quat(node_pos_robot[:,3:].detach().numpy()) * rotation_matrix).as_quat()) + # add base link shift + node_pos_robot[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i]) + node_pos.append(node_pos_robot) + node_pos = torch.cat(node_pos, dim=0) # use rotation transformer to convert quaternion to 6d rotation node_pos = torch.cat([node_pos[:,:3], self.rotation_transformer.forward(node_pos[:,3:])], dim=1) + obj_pos_tensor = self._get_object_pos(data) + node_pos = torch.cat((node_pos, obj_pos_tensor), dim=0) return node_pos @@ -159,15 +175,14 @@ def _get_node_feats(self, data): Returns node features from data ''' node_feats = [] - if self.control_mode == "OSC_POSE": - node_feats = torch.cat([torch.tensor(data["robot0_eef_pos"]), torch.tensor(data["robot0_eef_quat"])], dim=0) - node_feats = node_feats.reshape(1, -1) # add dimension - elif self.control_mode == "JOINT_VELOCITY": - node_feats.append(torch.tensor([*data["robot0_joint_vel"], *data["robot0_gripper_qvel"]]).reshape(1,-1)) - node_feats = torch.cat(node_feats).T - elif self.control_mode == "JOINT_POSITION": - node_feats.append(torch.tensor([*data[f"robot0_joint_pos"], *data["robot0_gripper_qpos"]]).reshape(1,-1)) - node_feats = torch.cat(node_feats, dim=0).T + for i in range(len(self.robots)): + if self.control_mode == "OSC_POSE": + node_feats.append(torch.cat([torch.tensor(data[f"robot{i}_eef_pos"]), torch.tensor(data[f"robot{i}_eef_quat"])], dim=0).reshape(1, -1)) # add dimension + elif self.control_mode == "JOINT_VELOCITY": + node_feats.append(torch.tensor([*data[f"robot{i}_joint_vel"], 
*data[f"robot{i}_gripper_qvel"]]).reshape(1,-1).T) + elif self.control_mode == "JOINT_POSITION": + node_feats.append(torch.tensor([*data[f"robot{i}_joint_pos"], *data[f"robot{i}_gripper_qpos"]]).reshape(1,-1).T) + node_feats = torch.cat(node_feats, dim=0) return node_feats @@ -204,6 +219,7 @@ def _robosuite_obs_to_robomimic_graph(self, obs): ''' node_feats = torch.tensor([]) node_pos = torch.tensor([]) + robot_i_data = {} for i in range(len(self.robots)): j = i*39 @@ -221,15 +237,16 @@ def _robosuite_obs_to_robomimic_graph(self, obs): gripper_pose = obs[j + 35:j + 37] gripper_vel = obs[j + 37:j + 39] # Skip 2 - gripper joint velocities - robot_i_data = { - "robot0_joint_pos": robot_joint_pos, - "robot0_joint_vel": robot_joint_vel, - "robot0_eef_pos": eef_pose, - "robot0_eef_quat": eef_6d, - "robot0_gripper_qpos": gripper_pose, - "robot0_gripper_qvel": gripper_vel - } - node_feats = torch.cat([node_feats, self._get_node_feats(robot_i_data)], dim=0) + robot_i_data.update({ + f"robot{i}_joint_pos": robot_joint_pos, + f"robot{i}_joint_vel": robot_joint_vel, + f"robot{i}_eef_pos": eef_pose, + f"robot{i}_eef_quat": eef_6d, + f"robot{i}_gripper_qpos": gripper_pose, + f"robot{i}_gripper_qvel": gripper_vel + }) + + node_feats = torch.cat([node_feats, self._get_node_feats(robot_i_data)], dim=0) robot_i_data["object"] = obs[len(self.robots)*39:] node_pos = self._get_node_pos(robot_i_data) From c237bf6406032ce101fb3a44bc02c34c8afef930 Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Wed, 3 Apr 2024 22:54:14 +0200 Subject: [PATCH 6/9] hard-coded EEF only observations from FK --- .../config/policy/graph_ddpm_policy.yaml | 2 +- imitation/model/graph_diffusion.py | 6 ++- imitation/policy/graph_ddpm_policy.py | 38 ++++++++++++++++--- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/imitation/config/policy/graph_ddpm_policy.yaml b/imitation/config/policy/graph_ddpm_policy.yaml index 87c894a..15606f0 100644 --- a/imitation/config/policy/graph_ddpm_policy.yaml +++ b/imitation/config/policy/graph_ddpm_policy.yaml @@ -17,7 +17,7 @@ denoising_network: num_edge_types: ${policy.num_edge_types} num_layers: 3 hidden_dim: 128 - num_diffusion_iters: ${policy.num_diffusion_iters} + num_diffusion_steps: ${policy.num_diffusion_iters} ckpt_path: ./weights/diffusion_graph_policy_${task.task_name}_${task.dataset_type}_${task.control_mode}_${policy.num_diffusion_iters}iters.pt lr: 1e-4 batch_size: 32 diff --git a/imitation/model/graph_diffusion.py b/imitation/model/graph_diffusion.py index cf4b076..5dd9b94 100644 --- a/imitation/model/graph_diffusion.py +++ b/imitation/model/graph_diffusion.py @@ -356,15 +356,17 @@ def positionalencoding(self, lengths): return pes - def forward(self, x, edge_index, edge_attr, x_coord, cond, timesteps, batch=None): + def forward(self, x, edge_index, edge_attr, x_coord, cond, timesteps, batch=None, batch_cond=None, edge_index_cond=None, edge_attr_cond=None, x_coord_cond=None): # make sure x and edge_attr are of type float, for the MLPs x = x.float().to(self.device).flatten(start_dim=1) edge_attr = edge_attr.float().to(self.device).unsqueeze(-1) # add channel dimension edge_index = edge_index.to(self.device) cond = cond.float().to(self.device) x_coord = x_coord.float().to(self.device) + x_coord_cond = x_coord_cond.float().to(self.device) timesteps = timesteps.to(self.device) batch = batch.long().to(self.device) + batch_cond = batch_cond.long().to(self.device) batch_size = batch[-1] + 1 timesteps_embed = self.diffusion_step_encoder(self.pe[timesteps]) @@ -381,7 +383,7 @@ def 
forward(self, x, edge_index, edge_attr, x_coord, cond, timesteps, batch=None h_e = self.edge_embedding(edge_attr.reshape(-1, 1)) # FiLM generator - embed = self.cond_encoder(cond, edge_index, x_coord, edge_attr, batch=batch) + embed = self.cond_encoder(cond, edge_index_cond, x_coord_cond, edge_attr_cond, batch=batch_cond) embed = embed.reshape(self.num_layers, batch_size, 2, (self.hidden_dim + self.diffusion_step_embed_dim)) scales = embed[:,:,0,...] biases = embed[:,:,1,...] diff --git a/imitation/policy/graph_ddpm_policy.py b/imitation/policy/graph_ddpm_policy.py index 5b5994e..323629a 100644 --- a/imitation/policy/graph_ddpm_policy.py +++ b/imitation/policy/graph_ddpm_policy.py @@ -121,9 +121,13 @@ def get_action(self, obs_deque): edge_index = G_t.edge_index, edge_attr = G_t.edge_attr, x_coord = nobs[:,-1,:3], - cond = nobs[:,:,3:], + cond = nobs[-2:,:,3:], # only end-effector and object timesteps = torch.tensor([k], dtype=torch.long, device=self.device), - batch = torch.zeros(naction.shape[0], dtype=torch.long, device=self.device) + batch = torch.zeros(naction.shape[0], dtype=torch.long, device=self.device), + batch_cond = torch.zeros(nobs[-2:,:,3:].shape[0], dtype=torch.long, device=self.device), + edge_index_cond = torch.tensor([[0,1],[1,0]], dtype=torch.long, device=self.device), + edge_attr_cond = torch.ones(2, 1, device=self.device), + x_coord_cond = nobs[-2:,-1,3:] ) # inverse diffusion step (remove noise) @@ -170,7 +174,14 @@ def validate(self, dataset=None, model_path="last.pt"): # observation as FiLM conditioning # (B, node, obs_horizon, obs_dim) - obs_cond = nobs[:,:,3:] + obs_cond = nobs[:,:,3:] # only 6D rotation + # filter only 2 last nodes of each graph by batch.ptr + obs_cond = torch.cat([obs_cond[batch.ptr[i+1]-2:batch.ptr[i+1]] for i in range(B)], dim=0) + batch_cond = torch.cat([batch.batch[batch.ptr[i+1]-2:batch.ptr[i+1]] for i in range(B)], dim=0) + x_coord_cond = torch.cat([batch.y[batch.ptr[i+1]-2:batch.ptr[i+1],-1,:3] for i in range(B)], dim=0) + edge_index_cond = torch.tensor([[[2*i,2*i+1],[2*i+1,2*i]] for i in range(batch_cond.shape[0] // 2)], dtype=torch.long, device=self.device) + edge_index_cond = edge_index_cond.flatten(end_dim=1).T + edge_attr_cond = torch.ones(edge_index_cond.shape[1], 1, device=self.device) # (B, obs_horizon * obs_dim) obs_cond = obs_cond.flatten(start_dim=1) @@ -202,14 +213,18 @@ def validate(self, dataset=None, model_path="last.pt"): noise = noise.flatten(end_dim=1) # predict the noise residual - noise_pred, x = self.ema_noise_pred_net( + noise_pred, x = self.noise_pred_net( noisy_actions, batch.edge_index, batch.edge_attr, x_coord = batch.y[:,-1,:3], cond=obs_cond, timesteps=timesteps, - batch=batch.batch) + batch=batch.batch, + batch_cond=batch_cond, + edge_index_cond=edge_index_cond, + edge_attr_cond=edge_attr_cond, + x_coord_cond=x_coord_cond) # L2 loss loss = nn.functional.mse_loss(noise_pred, noise) @@ -276,6 +291,13 @@ def train(self, # observation as FiLM conditioning # (B, node, obs_horizon, obs_dim) obs_cond = nobs[:,:,3:] # only 6D rotation + # filter only 2 last nodes of each graph by batch.ptr + obs_cond = torch.cat([obs_cond[batch.ptr[i+1]-2:batch.ptr[i+1]] for i in range(B)], dim=0) + batch_cond = torch.cat([batch.batch[batch.ptr[i+1]-2:batch.ptr[i+1]] for i in range(B)], dim=0) + x_coord_cond = torch.cat([batch.y[batch.ptr[i+1]-2:batch.ptr[i+1],-1,:3] for i in range(B)], dim=0) + edge_index_cond = torch.tensor([[[2*i,2*i+1],[2*i+1,2*i]] for i in range(batch_cond.shape[0] // 2)], dtype=torch.long, device=self.device) + 
edge_index_cond = edge_index_cond.flatten(end_dim=1).T + edge_attr_cond = torch.ones(edge_index_cond.shape[1], 1, device=self.device) # (B, obs_horizon * obs_dim) obs_cond = obs_cond.flatten(start_dim=1) @@ -315,7 +337,11 @@ def train(self, x_coord = batch.y[:,-1,:3], cond=obs_cond, timesteps=timesteps, - batch=batch.batch) + batch=batch.batch, + batch_cond=batch_cond, + edge_index_cond=edge_index_cond, + edge_attr_cond=edge_attr_cond, + x_coord_cond=x_coord_cond) # L2 loss loss = nn.functional.mse_loss(noise_pred, noise) From f252ff0e693f8dea3d9fcf7fb134de26a4d4781c Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Thu, 4 Apr 2024 19:12:08 +0200 Subject: [PATCH 7/9] OSC environment wrapper running --- .../config/policy/graph_ddpm_policy.yaml | 4 +- imitation/config/task/lift_graph.yaml | 2 +- imitation/dataset/robomimic_graph_dataset.py | 55 ++++++++++++------- imitation/env/robomimic_graph_wrapper.py | 46 +++++++++++----- imitation/policy/graph_ddpm_policy.py | 2 +- 5 files changed, 71 insertions(+), 38 deletions(-) diff --git a/imitation/config/policy/graph_ddpm_policy.yaml b/imitation/config/policy/graph_ddpm_policy.yaml index 15606f0..38404ed 100644 --- a/imitation/config/policy/graph_ddpm_policy.yaml +++ b/imitation/config/policy/graph_ddpm_policy.yaml @@ -1,7 +1,7 @@ _target_: imitation.policy.graph_ddpm_policy.GraphConditionalDDPMPolicy obs_dim: 9 -action_dim: 9 -node_feature_dim: 1 # from [joint_val, node_flag] +action_dim: 1 #${eval:'1 if ${task.control_mode} == "OSC_POSE" else 9'} +node_feature_dim: 9 # from [joint_val, node_flag] num_edge_types: 2 # robot joints, object-robot pred_horizon: ${pred_horizon} obs_horizon: ${obs_horizon} diff --git a/imitation/config/task/lift_graph.yaml b/imitation/config/task/lift_graph.yaml index 44b3e92..74e83cb 100644 --- a/imitation/config/task/lift_graph.yaml +++ b/imitation/config/task/lift_graph.yaml @@ -5,7 +5,7 @@ dataset_path: &dataset_path ./data/lift/${task.dataset_type}/low_dim_v141.hdf5 max_steps: ${max_steps} -control_mode: "JOINT_POSITION" +control_mode: "OSC_POSE" obs_keys: &obs_keys ['robot0_eef_pos', 'object'] action_keys: &action_keys ['robot0_joint_pos'] diff --git a/imitation/dataset/robomimic_graph_dataset.py b/imitation/dataset/robomimic_graph_dataset.py index 0a34365..a8188b2 100644 --- a/imitation/dataset/robomimic_graph_dataset.py +++ b/imitation/dataset/robomimic_graph_dataset.py @@ -29,7 +29,7 @@ def __init__(self, control_mode="JOINT_VELOCITY", base_link_shift=[0.0, 0.0, 0.0]): self.control_mode : str = control_mode - self.node_feature_dim : int = node_feature_dim + self.node_feature_dim : int = 10 if control_mode == "OSC_POSE" else node_feature_dim self.action_keys : List = action_keys self.pred_horizon : int = pred_horizon self.obs_horizon : int = obs_horizon @@ -63,7 +63,7 @@ def __init__(self, self.constant_stats = { "y": torch.tensor([False, False, False, True, True, True, True, True, True]), # mask rotations for robot and object nodes - "x": torch.tensor([False, True]) # node type flag is constant + "x": torch.tensor([False, False, False, True, True, True, True, True, True, True]) # node type flag is constant } @@ -100,8 +100,11 @@ def _get_object_pos(self, data, t): return obj_state_tensor def _get_node_pos(self, data, t): - node_pos = calculate_panda_joints_positions([*data["robot0_joint_pos"][t], *data["robot0_gripper_qpos"][t]]) - node_pos[:,:3] += torch.tensor(self.BASE_LINK_SHIFT) + if self.control_mode == "OSC_POSE": + node_pos = torch.cat([torch.tensor(data["robot0_eef_pos"][t]), 
torch.tensor(data["robot0_eef_quat"][t])], dim=0).unsqueeze(0) + else: + node_pos = calculate_panda_joints_positions([*data["robot0_joint_pos"][t], *data["robot0_gripper_qpos"][t]]) + node_pos[:,:3] += torch.tensor(self.BASE_LINK_SHIFT) # use rotation transformer to convert quaternion to 6d rotation node_pos = torch.cat([node_pos[:,:3], self.rotation_transformer.forward(node_pos[:,3:])], dim=1) obj_pos_tensor = self._get_object_pos(data, t) @@ -116,13 +119,14 @@ def _get_node_feats(self, data, t_vals): T = len(t_vals) node_feats = [] if self.control_mode == "OSC_POSE": - node_feats = torch.cat([torch.tensor(data["robot0_eef_pos"][t_vals]), torch.tensor(data["robot0_eef_quat"][t_vals])], dim=0) - node_feats = node_feats.reshape(T, -1) # add dimension + node_feats = torch.cat([torch.tensor(data["robot0_eef_pos"][t_vals]), torch.tensor(data["robot0_eef_quat"][t_vals])], dim=1) + # use rotation transformer to convert quaternion to 6d rotation + node_feats = torch.cat([node_feats[:,:3], self.rotation_transformer.forward(node_feats[:,3:])], dim=1).unsqueeze(0) if self.control_mode == "JOINT_VELOCITY": - node_feats = torch.cat([torch.tensor(data[f"robot0_joint_vel"][t_vals]), torch.tensor(data["robot0_gripper_qvel"][t_vals])], dim=1).T.unsqueeze(2) + node_feats = torch.cat([torch.tensor(data["robot0_joint_vel"][t_vals]), torch.tensor(data["robot0_gripper_qvel"][t_vals])], dim=1).T.unsqueeze(2) elif self.control_mode == "JOINT_POSITION": # [node, node_feats] - node_feats = torch.cat([torch.tensor(data[f"robot0_joint_pos"][t_vals]), torch.tensor(data["robot0_gripper_qpos"][t_vals])], dim=1).T.unsqueeze(2) + node_feats = torch.cat([torch.tensor(data["robot0_joint_pos"][t_vals]), torch.tensor(data["robot0_gripper_qpos"][t_vals])], dim=1).T.unsqueeze(2) # add dimension for NODE_TYPE flag, which is 0 for robot and 1 for objects node_feats = torch.cat((node_feats, self.ROBOT_NODE_TYPE*torch.ones((node_feats.shape[0],node_feats.shape[1],1))), dim=2) @@ -163,7 +167,7 @@ def _get_edge_attrs(self, edge_index): return torch.tensor(edge_attrs, dtype=torch.long) @lru_cache(maxsize=None) - def _get_edge_index(self, num_nodes): + def _get_edge_index(self, num_nodes, control_mode): ''' Returns edge index for graph. 
- all robot nodes are connected to the previous robot node @@ -171,6 +175,9 @@ def _get_edge_index(self, num_nodes): ''' eef_idx = 8 edge_index = [] + if control_mode == "OSC_POSE": + eef_idx = 0 + return torch.tensor([[eef_idx, obj_idx] for obj_idx in range(eef_idx, num_nodes)]) for idx in range(eef_idx): edge_index.append([idx, idx+1]) @@ -204,7 +211,7 @@ def process(self): data_raw = self.dataset_root["data"][key]["obs"] node_feats = self._get_node_feats_horizon(data_raw, idx, self.pred_horizon) - edge_index = self._get_edge_index(node_feats.shape[0]) + edge_index = self._get_edge_index(node_feats.shape[0], self.control_mode) edge_attrs = self._get_edge_attrs(edge_index) y = self._get_y_horizon(data_raw, idx, self.obs_horizon) pos = self._get_node_pos(data_raw, idx + self.pred_horizon) @@ -313,15 +320,21 @@ def __init__(self, def _get_node_pos(self, data, t): node_pos = [] - for i in range(self.num_robots): - node_pos_robot = calculate_panda_joints_positions([*data[f"robot{i}_joint_pos"][t], *data[f"robot{i}_gripper_qpos"][t]]) - # rotate robot nodes - rotation_matrix = R.from_quat(self.BASE_LINK_ROTATION[i]) - node_pos_robot[:,:3] = torch.matmul(node_pos_robot[:,:3], torch.tensor(rotation_matrix.as_matrix())) - node_pos_robot[:,3:] = torch.tensor((R.from_quat(node_pos_robot[:,3:].detach().numpy()) * rotation_matrix).as_quat()) - # add base link shift - node_pos_robot[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i]) - node_pos.append(node_pos_robot) + if self.control_mode == "OSC_POSE": + for i in range(self.num_robots): + node_pos_robot = torch.cat([torch.tensor(data[f"robot{i}_eef_pos"][t]), torch.tensor(data[f"robot{i}_eef_quat"][t])], dim=0) + node_pos.append(node_pos_robot) + else: + for i in range(self.num_robots): + node_pos_robot = calculate_panda_joints_positions([*data[f"robot{i}_joint_pos"][t], *data[f"robot{i}_gripper_qpos"][t]]) + # rotate robot nodes + rotation_matrix = R.from_quat(self.BASE_LINK_ROTATION[i]) + node_pos_robot[:,:3] = torch.matmul(node_pos_robot[:,:3], torch.tensor(rotation_matrix.as_matrix())) + node_pos_robot[:,3:] = torch.tensor((R.from_quat(node_pos_robot[:,3:].detach().numpy()) * rotation_matrix).as_quat()) + # add base link shift + node_pos_robot[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i]) + node_pos.append(node_pos_robot) + node_pos = torch.cat(node_pos, dim=0) # use rotation transformer to convert quaternion to 6d rotation node_pos = torch.cat([node_pos[:,:3], self.rotation_transformer.forward(node_pos[:,3:])], dim=1) @@ -360,13 +373,15 @@ def _get_node_feats(self, data, t_vals): return node_feats @lru_cache(maxsize=None) - def _get_edge_index(self, num_nodes): + def _get_edge_index(self, num_nodes, control_mode): ''' Returns edge index for graph. 
- all robot nodes are connected to the previous robot node - all object nodes are connected to the last robot node (end-effector) ''' assert len(self.eef_idx) == self.num_robots + 1 + if control_mode == "OSC_POSE": + return torch.tensor([[eef_id, obj_idx] for eef_id in self.eef_idx for obj_idx in range(self.eef_idx[self.num_robots], num_nodes)]) edge_index = [[self.eef_idx[0], self.eef_idx[1] + 1]] # robot0 base link to robot1 base link edge_index += [[idx, idx+1] for id_robot in range(1, len(self.eef_idx)-1) for idx in range(self.eef_idx[id_robot-1], self.eef_idx[id_robot])] diff --git a/imitation/env/robomimic_graph_wrapper.py b/imitation/env/robomimic_graph_wrapper.py index 28132f1..68b018b 100644 --- a/imitation/env/robomimic_graph_wrapper.py +++ b/imitation/env/robomimic_graph_wrapper.py @@ -61,7 +61,7 @@ def __init__(self, ''' self.object_state_sizes = object_state_sizes self.object_state_keys = object_state_keys - self.node_feature_dim = node_feature_dim + self.node_feature_dim = 10 if control_mode == "OSC_POSE" else node_feature_dim self.control_mode = control_mode controller_config = load_controller_config(default_controller=self.control_mode) # override default controller config with user-specified values @@ -91,7 +91,7 @@ def __init__(self, self.observation_space = self.env.observation_space self.num_objects = len(object_state_keys) - self.NUM_GRAPH_NODES = 9 + self.num_objects # TODO add multi-robot support + self.NUM_GRAPH_NODES = 1 + self.num_objects if self.control_mode == "OSC_POSE" else 9 + self.num_objects self.BASE_LINK_SHIFT = base_link_shift self.BASE_LINK_ROTATION = base_link_rotation self.ROBOT_NODE_TYPE = 1 @@ -154,16 +154,19 @@ def _get_object_pos(self, data): def _get_node_pos(self, data): node_pos = [] for i in range(len(self.robots)): - node_pos_robot = calculate_panda_joints_positions([*data[f"robot{i}_joint_pos"], *data[f"robot{i}_gripper_qpos"]]) - rotation_matrix = R.from_quat(self.BASE_LINK_ROTATION[i]) - node_pos_robot[:,:3] = torch.matmul(node_pos_robot[:,:3], torch.tensor(rotation_matrix.as_matrix())) - node_pos_robot[:,3:] = torch.tensor((R.from_quat(node_pos_robot[:,3:].detach().numpy()) * rotation_matrix).as_quat()) - # add base link shift - node_pos_robot[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i]) + if self.control_mode == "OSC_POSE": + node_pos_robot = torch.cat([torch.tensor(data[f"robot{i}_eef_pos"]), torch.tensor(data[f"robot{i}_eef_6d_rot"])], dim=0).unsqueeze(0) + else: + node_pos_robot = calculate_panda_joints_positions([*data[f"robot{i}_joint_pos"], *data[f"robot{i}_gripper_qpos"]]) + rotation_matrix = R.from_quat(self.BASE_LINK_ROTATION[i]) + node_pos_robot[:,:3] = torch.matmul(node_pos_robot[:,:3], torch.tensor(rotation_matrix.as_matrix())) + node_pos_robot[:,3:] = torch.tensor((R.from_quat(node_pos_robot[:,3:].detach().numpy()) * rotation_matrix).as_quat()) + # add base link shift + node_pos_robot[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i]) + # use rotation transformer to convert quaternion to 6d rotation + node_pos_robot = torch.cat([node_pos_robot[:,:3], self.rotation_transformer.forward(node_pos_robot[:,3:])], dim=1) node_pos.append(node_pos_robot) node_pos = torch.cat(node_pos, dim=0) - # use rotation transformer to convert quaternion to 6d rotation - node_pos = torch.cat([node_pos[:,:3], self.rotation_transformer.forward(node_pos[:,3:])], dim=1) obj_pos_tensor = self._get_object_pos(data) node_pos = torch.cat((node_pos, obj_pos_tensor), dim=0) return node_pos @@ -177,7 +180,8 @@ def _get_node_feats(self, data): node_feats = [] 
for i in range(len(self.robots)): if self.control_mode == "OSC_POSE": - node_feats.append(torch.cat([torch.tensor(data[f"robot{i}_eef_pos"]), torch.tensor(data[f"robot{i}_eef_quat"])], dim=0).reshape(1, -1)) # add dimension + node_feats.append(torch.cat([torch.tensor(data[f"robot{i}_eef_pos"]), torch.tensor(data[f"robot{i}_eef_6d_rot"])], dim=0).unsqueeze(0)) + elif self.control_mode == "JOINT_VELOCITY": node_feats.append(torch.tensor([*data[f"robot{i}_joint_vel"], *data[f"robot{i}_gripper_qvel"]]).reshape(1,-1).T) elif self.control_mode == "JOINT_POSITION": @@ -192,7 +196,10 @@ def _get_edges(self): - all robot nodes are connected to the previous robot node - all object nodes are connected to the last robot node (end-effector) ''' - eef_idx = 8 + if self.control_mode == "OSC_POSE": + eef_idx = 0 + else: + eef_idx = 8 edge_index = [] edge_attrs = [] for idx in range(eef_idx): @@ -241,7 +248,7 @@ def _robosuite_obs_to_robomimic_graph(self, obs): f"robot{i}_joint_pos": robot_joint_pos, f"robot{i}_joint_vel": robot_joint_vel, f"robot{i}_eef_pos": eef_pose, - f"robot{i}_eef_quat": eef_6d, + f"robot{i}_eef_6d_rot": eef_6d, f"robot{i}_gripper_qpos": gripper_pose, f"robot{i}_gripper_qvel": gripper_vel }) @@ -249,7 +256,7 @@ def _robosuite_obs_to_robomimic_graph(self, obs): node_feats = torch.cat([node_feats, self._get_node_feats(robot_i_data)], dim=0) robot_i_data["object"] = obs[len(self.robots)*39:] - node_pos = self._get_node_pos(robot_i_data) + node_pos = self._get_node_pos(robot_i_data) # should have shape [2, 9] # add dimension for NODE_TYPE, which is 0 for robot and 1 for objects node_feats = torch.cat((node_feats, self.ROBOT_NODE_TYPE*torch.ones((node_feats.shape[0],1))), dim=1) @@ -276,6 +283,17 @@ def reset(self): def step(self, action): final_action = [] + if self.control_mode == "OSC_POSE": + for i in range(len(self.robots)): + j = i*9 + robot_joint_pos = action[j:j + 3] + robot_joint_6d_rot = action[j + 3:j + 9] + # convert 6d rotation to quaternion + robot_joint_quat = self.rotation_transformer.inverse(robot_joint_6d_rot) + final_action.extend([*robot_joint_pos, *robot_joint_quat]) + obs, reward, done, _, info = self.env.step(final_action) + return self._robosuite_obs_to_robomimic_graph(obs), reward, done, info + for i in range(len(self.robots)): ''' Robosuite's action space is composed of 7 joint velocities and 1 gripper velocity, while diff --git a/imitation/policy/graph_ddpm_policy.py b/imitation/policy/graph_ddpm_policy.py index 323629a..4fa6254 100644 --- a/imitation/policy/graph_ddpm_policy.py +++ b/imitation/policy/graph_ddpm_policy.py @@ -142,7 +142,7 @@ def get_action(self, obs_deque): naction = naction.detach().to('cpu') if self.use_normalization: action_pred = self.dataset.unnormalize_data(naction, stats_key='x').numpy() - action = action_pred[:9,:,0].T + action = action_pred[0,:,:] # (action_horizon, action_dim) return action From e923dd8f710e2450ed996186a9a446f791c6c51e Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Sun, 7 Apr 2024 19:25:39 +0200 Subject: [PATCH 8/9] Graph HeteroData --- imitation/dataset/robomimic_graph_dataset.py | 135 ++++++++++--------- imitation/env/robomimic_graph_wrapper.py | 14 +- imitation/policy/egnn_policy.py | 10 +- imitation/utils/generic.py | 2 +- 4 files changed, 82 insertions(+), 79 deletions(-) diff --git a/imitation/dataset/robomimic_graph_dataset.py b/imitation/dataset/robomimic_graph_dataset.py index d2f1fae..2901317 100644 --- a/imitation/dataset/robomimic_graph_dataset.py +++ b/imitation/dataset/robomimic_graph_dataset.py @@ 
-1,6 +1,6 @@ from typing import Callable, Optional -from torch_geometric.data import Dataset, Data, InMemoryDataset +from torch_geometric.data import Dataset, Data, InMemoryDataset, HeteroData import logging import h5py import os.path as osp @@ -30,9 +30,10 @@ def __init__(self, base_link_shift=[0.0, 0.0, 0.0], base_link_rotation=[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0]]): self.control_mode : str = control_mode - self.node_feature_dim : int = 10 if control_mode == "OSC_POSE" else node_feature_dim + self.obs_mode : str = "OSC_POSE" self.robots : List = robots self.num_robots : int = len(self.robots) + self.node_feature_dim : int = 8 if control_mode == "OSC_POSE" else node_feature_dim self.pred_horizon : int = pred_horizon self.obs_horizon : int = obs_horizon self.object_state_sizes : Dict = object_state_sizes # can be taken from https://github.com/ARISE-Initiative/robosuite/tree/master/robosuite/environments/manipulation @@ -81,14 +82,6 @@ def processed_file_names(self): names = [f"data_{i}.pt" for i in range(self.len())] return names - # @lru_cache(maxsize=None) - def _get_object_feats(self, num_objects, node_feature_dim, OBJECT_NODE_TYPE, T): # no associated joint values - # create tensor of same dimension return super()._get_node_feats(data, t) as node_feats - obj_state_tensor = torch.zeros((num_objects, T, node_feature_dim)) - # add dimension for NODE_TYPE, which is 0 for robot and 1 for objects - obj_state_tensor[:,:,-1] = OBJECT_NODE_TYPE - return obj_state_tensor - def _get_object_pos(self, data, t): obj_state_tensor = torch.zeros((self.num_objects, 9)) # 3 for position, 6 for 6D rotation @@ -105,7 +98,7 @@ def _get_object_pos(self, data, t): return obj_state_tensor - def _get_node_pos(self, data, t): + def _get_node_pos(self, data, t, modality="action"): node_pos = [] for i in range(self.num_robots): if self.control_mode == "OSC_POSE": @@ -123,11 +116,12 @@ def _get_node_pos(self, data, t): node_pos = torch.cat(node_pos, dim=0) # use rotation transformer to convert quaternion to 6d rotation node_pos = torch.cat([node_pos[:,:3], self.rotation_transformer.forward(node_pos[:,3:])], dim=1) - obj_pos_tensor = self._get_object_pos(data, t) - node_pos = torch.cat((node_pos, obj_pos_tensor), dim=0) + if modality == "observation": + obj_pos_tensor = self._get_object_pos(data, t) + node_pos = torch.cat((node_pos, obj_pos_tensor), dim=0) return node_pos - def _get_node_feats(self, data, t_vals): + def _get_node_actions(self, data, t_vals, modality="action"): ''' Calculate node features for time steps t_vals t_vals: list of time steps @@ -136,9 +130,7 @@ def _get_node_feats(self, data, t_vals): node_feats = [] if self.control_mode == "OSC_POSE": for i in range(self.num_robots): - node_feats = torch.cat([torch.tensor(data["robot0_eef_pos"][t_vals]), torch.tensor(data["robot0_eef_quat"][t_vals])], dim=1) - # use rotation transformer to convert quaternion to 6d rotation - node_feats = torch.cat([node_feats[:,:3], self.rotation_transformer.forward(node_feats[:,3:])], dim=1).unsqueeze(0) + node_feats.append(torch.cat([torch.tensor(data["actions"][t_vals])], dim=1).unsqueeze(0)) elif self.control_mode == "JOINT_POSITION": for i in range(self.num_robots): node_feats.append(torch.cat([ @@ -151,34 +143,32 @@ def _get_node_feats(self, data, t_vals): torch.tensor(data[f"robot{i}_gripper_qvel"][t_vals])], dim=1).T.unsqueeze(2)) node_feats = torch.cat(node_feats, dim=0) # [num_robots*num_joints, T, 1] - # add dimension for NODE_TYPE, which is 0 for robot and 1 for objects - node_feats = 
torch.cat((node_feats, self.ROBOT_NODE_TYPE*torch.ones((node_feats.shape[0],node_feats.shape[1],1))), dim=2) - - obj_state_tensor = self._get_object_feats(self.num_objects, self.node_feature_dim, self.OBJECT_NODE_TYPE, T) - - node_feats = torch.cat((node_feats, obj_state_tensor), dim=0) return node_feats - def _get_node_feats_horizon(self, data, idx, horizon): + def _get_node_actions_horizon(self, data, idx, horizon): ''' Calculate node features for self.obs_horizon time steps ''' node_feats = [] # calculate node features for timesteps idx to idx + horizon t_vals = list(range(idx, idx + horizon)) - node_feats = self._get_node_feats(data, t_vals) + node_feats = self._get_node_actions(data, t_vals) return node_feats - @lru_cache(maxsize=None) - def _get_edge_attrs(self, edge_index): + # @lru_cache(maxsize=None) + def _get_edge_attrs(self, edge_index, modality="action"): ''' Attribute edge types to edges - self.ROBOT_LINK_EDGE for edges between robot nodes - self.OBJECT_ROBOT_EDGE for edges between robot and object nodes ''' edge_attrs = [] - num_nodes = torch.max(edge_index) - for edge in edge_index.t(): + + if modality == "action": + return torch.ones(edge_index.shape[1], dtype=torch.long) * self.ROBOT_LINK_EDGE + + num_nodes = torch.max(edge_index).item() + for edge in edge_index.T: # num nodes - self.num_objects is the index of the last robot node if edge[0] <= num_nodes - self.num_objects and edge[1] <= num_nodes - self.num_objects: edge_attrs.append(self.ROBOT_LINK_EDGE) @@ -187,43 +177,47 @@ def _get_edge_attrs(self, edge_index): edge_attrs.append(self.OBJECT_ROBOT_EDGE) return torch.tensor(edge_attrs, dtype=torch.long) - @lru_cache(maxsize=None) - def _get_edge_index(self, num_nodes, control_mode): + # @lru_cache(maxsize=None) + def _get_edge_index(self, num_nodes, modality="action"): ''' Returns edge index for graph. 
- all robot nodes are connected to the previous robot node - all object nodes are connected to the last robot node (end-effector) ''' - assert len(self.eef_idx) == self.num_robots + 1 edge_index = [] - if control_mode == "OSC_POSE": - eef_idx = 0 - return torch.tensor([[eef_idx, obj_idx] for obj_idx in range(eef_idx, num_nodes)]) - for idx in range(eef_idx): - edge_index.append([idx, idx+1]) + graph_type = self.control_mode if modality == "action" else self.obs_mode + eef_idx = self.eef_idx + if graph_type == "OSC_POSE": # 1 node per robot + eef_idx = list(range(-1, self.num_robots)) + edge_index += [[0, 0]] + edge_index += [[eef_idx[robot_node], eef_idx[robot_node+1]] for robot_node in range(1, len(eef_idx)-1)] + else: # JOINT_POSITION or JOINT_VELOCITY + assert len(self.eef_idx) == self.num_robots + 1 + + if len(self.eef_idx) == 3: # 2 robots + edge_index = [[self.eef_idx[0]+ 1, self.eef_idx[1] + 1]] # robot0 base link to robot1 base link + for robot in range(self.num_robots): + # Connectivity of all robot nodes to the previous robot node + edge_index += [[idx, idx+1] for idx in range(self.eef_idx[robot]+ 1, self.eef_idx[robot+1])] + + if modality == "observation": + # Connectivity of all objects to all robot nodes + edge_index += [[node_idx, idx] for idx in range(eef_idx[-1] + 1, num_nodes) for node_idx in range(eef_idx[self.num_robots] + 1)] - if len(self.eef_idx) == 3: # 2 robots - edge_index = [[self.eef_idx[0]+ 1, self.eef_idx[1] + 1]] # robot0 base link to robot1 base link - for robot in range(self.num_robots): - # Connectivity of all robot nodes to the previous robot node - edge_index += [[idx, idx+1] for idx in range(self.eef_idx[robot]+ 1, self.eef_idx[robot+1])] - # Connectivity of all other nodes to all robot nodes - edge_index += [[node_idx, idx] for idx in range(self.eef_idx[-1] + 1, num_nodes) for node_idx in range(self.eef_idx[self.num_robots] + 1)] - # edge_index.append(torch.tensor([node_idx, idx]) for node_idx in range(self.eef_idx[self.num_robots] + 1)) edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous() return edge_index - def _get_y_horizon(self, data, idx, horizon): + def _get_obs_horizon(self, data, idx, horizon, modality="observation"): ''' - Get y (observation) for time step t. Should contain only task-space joint positions. + Get observation node features for time step t. Should contain only task-space joint positions. 
''' - y = [] + obs = [] for t in range(idx, idx - horizon,-1): if t < 0: - y.append(self._get_node_pos(data, 0)) # use fixed first observation for beginning of episode + obs.append(self._get_node_pos(data, 0, modality="observation")) # use fixed first observation for beginning of episode else: - y.append(self._get_node_pos(data, t)) - return torch.stack(y, dim=1) + obs.append(self._get_node_pos(data, t, modality="observation")) + return torch.stack(obs, dim=1) def process(self): @@ -234,21 +228,36 @@ def process(self): for idx in range(episode_length - self.pred_horizon): - data_raw = self.dataset_root["data"][key]["obs"] - node_feats = self._get_node_feats_horizon(data_raw, idx, self.pred_horizon) - edge_index = self._get_edge_index(node_feats.shape[0], self.control_mode) - edge_attrs = self._get_edge_attrs(edge_index) - y = self._get_y_horizon(data_raw, idx, self.obs_horizon) - pos = self._get_node_pos(data_raw, idx + self.pred_horizon) + data_raw = self.dataset_root["data"][key] + action_feats = self._get_node_actions_horizon(data_raw, idx, self.pred_horizon) + action_edge_index = self._get_edge_index(action_feats.shape[0], modality="action") + action_edge_attrs = self._get_edge_attrs(action_edge_index, modality="action") + action_pos = self._get_node_pos(data_raw["obs"], idx + self.pred_horizon, modality="action") + + obs_feats = self._get_obs_horizon(data_raw["obs"], idx, self.obs_horizon) + obs_pos = self._get_node_pos(data_raw["obs"], idx, modality="observation") + obs_edge_index = self._get_edge_index(obs_feats.shape[0], modality="observation") + obs_edge_attrs = self._get_edge_attrs(obs_edge_index, modality="observation") - data = Data( - x=node_feats, - edge_index=edge_index, - edge_attr=edge_attrs, - y=y, + action_graph = Data( + x=action_feats, + edge_index=action_edge_index, + edge_attr=action_edge_attrs, + y=action_pos, time=torch.tensor([idx], dtype=torch.long)/ episode_length, - pos=pos + pos=action_pos ) + obs_graph = Data( + x=obs_feats, + edge_index=obs_edge_index, + edge_attr=obs_edge_attrs, + time=torch.tensor([idx], dtype=torch.long)/ episode_length, + pos=obs_pos + ) + + data = HeteroData() + data["action"] = action_graph + data["observation"] = obs_graph torch.save(data, osp.join(self.processed_dir, f'data_{idx_global}.pt')) idx_global += 1 diff --git a/imitation/env/robomimic_graph_wrapper.py b/imitation/env/robomimic_graph_wrapper.py index f09b3d4..5d5f998 100644 --- a/imitation/env/robomimic_graph_wrapper.py +++ b/imitation/env/robomimic_graph_wrapper.py @@ -61,7 +61,7 @@ def __init__(self, ''' self.object_state_sizes = object_state_sizes self.object_state_keys = object_state_keys - self.node_feature_dim = 10 if control_mode == "OSC_POSE" else node_feature_dim + self.node_feature_dim = 8 if control_mode == "OSC_POSE" else node_feature_dim self.control_mode = control_mode controller_config = load_controller_config(default_controller=self.control_mode) # override default controller config with user-specified values @@ -186,7 +186,8 @@ def _get_node_feats(self, data): node_feats = [] for i in range(len(self.robots)): if self.control_mode == "OSC_POSE": - node_feats.append(torch.cat([torch.tensor(data[f"robot{i}_eef_pos"]), torch.tensor(data[f"robot{i}_eef_6d_rot"])], dim=0).unsqueeze(0)) + # node_feats.append(torch.cat([torch.tensor(data[f"robot{i}_eef_pos"]), torch.tensor(data[f"robot{i}_eef_6d_rot"])], dim=0).unsqueeze(0)) + node_feats.append(torch.zeros((1,7))) elif self.control_mode == "JOINT_VELOCITY": 
node_feats.append(torch.tensor([*data[f"robot{i}_joint_vel"], *data[f"robot{i}_gripper_qvel"]]).reshape(1,-1).T) elif self.control_mode == "JOINT_POSITION": @@ -301,14 +302,7 @@ def reset(self): def step(self, action): final_action = [] if self.control_mode == "OSC_POSE": - for i in range(len(self.robots)): - j = i*9 - robot_joint_pos = action[j:j + 3] - robot_joint_6d_rot = action[j + 3:j + 9] - # convert 6d rotation to quaternion - robot_joint_quat = self.rotation_transformer.inverse(robot_joint_6d_rot) - final_action.extend([*robot_joint_pos, *robot_joint_quat]) - obs, reward, done, _, info = self.env.step(final_action) + obs, reward, done, _, info = self.env.step(action) return self._robosuite_obs_to_robomimic_graph(obs), reward, done, info for i in range(len(self.robots)): diff --git a/imitation/policy/egnn_policy.py b/imitation/policy/egnn_policy.py index 8726d27..bce2baa 100644 --- a/imitation/policy/egnn_policy.py +++ b/imitation/policy/egnn_policy.py @@ -73,7 +73,7 @@ def get_action(self, obs_deque): x=y[:,-1,:3].to(self.device).float(), ) pred = pred.reshape(-1, self.pred_horizon, self.node_feature_dim) - return pred[:9,:,0].T.detach().cpu().numpy() # return joint values only + return pred[0,:,:7].detach().cpu().numpy() # return joint values only def validate(self, dataset, model_path): ''' @@ -118,7 +118,7 @@ def train(self, dataset, num_epochs, model_path, seed=0): ) loss_fn = nn.MSELoss() - optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr) + optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=1e-6) # LR scheduler with warmup lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.8) # visualize data in batch @@ -140,7 +140,7 @@ def train(self, dataset, num_epochs, model_path, seed=0): x=nbatch.y[:,-1,:3].to(self.device).float(), ) pred = pred.reshape(-1, self.pred_horizon, self.node_feature_dim) - loss = loss_fn(pred, action) + loss = loss_fn(pred[:,:,0], action[:,:,0]) # loss_x = loss_fn(x, nbatch.pos[:,:3].to(self.device).float()) loss.backward() optimizer.step() @@ -149,11 +149,11 @@ def train(self, dataset, num_epochs, model_path, seed=0): pbar.set_postfix({"loss": loss.item()}) wandb.log({"loss": loss.item()}) - + # save model + torch.save(self.model.state_dict(), model_path) self.global_epoch += 1 wandb.log({"epoch": self.global_epoch, "loss": loss.item()}) - # save model torch.save(self.model.state_dict(), model_path) pbar.set_description(f"Epoch: {self.global_epoch}, Loss: {loss.item()}") diff --git a/imitation/utils/generic.py b/imitation/utils/generic.py index 365222c..52382f1 100644 --- a/imitation/utils/generic.py +++ b/imitation/utils/generic.py @@ -3,7 +3,7 @@ from scipy.spatial.transform import Rotation as R -from torch_robotics.torch_kinematics_tree.models.robots import DifferentiableFrankaPanda +from torch_kinematics_tree.models.robots import DifferentiableFrankaPanda def to_numpy(x): return x.detach().cpu().numpy() From 5f62a44cfb19dfe840b84c4eaadbd4348dcfc38b Mon Sep 17 00:00:00 2001 From: Caio Freitas Date: Tue, 9 Apr 2024 17:28:12 +0200 Subject: [PATCH 9/9] train ddpm policy with HeteroData --- .../config/policy/graph_ddpm_policy.yaml | 4 +- imitation/dataset/robomimic_graph_dataset.py | 36 +++++++++-------- imitation/model/graph_diffusion.py | 4 +- imitation/policy/graph_ddpm_policy.py | 39 +++++++++---------- 4 files changed, 42 insertions(+), 41 deletions(-) diff --git a/imitation/config/policy/graph_ddpm_policy.yaml b/imitation/config/policy/graph_ddpm_policy.yaml index 
3b6e2b5..cec6c35 100644 --- a/imitation/config/policy/graph_ddpm_policy.yaml +++ b/imitation/config/policy/graph_ddpm_policy.yaml @@ -1,7 +1,7 @@ _target_: imitation.policy.graph_ddpm_policy.GraphConditionalDDPMPolicy obs_dim: ${task.obs_dim} action_dim: 1 #${eval:'1 if ${task.control_mode} == "OSC_POSE" else ${task.action_dim}'} -node_feature_dim: 9 # from [joint_val, node_flag] +node_feature_dim: 1 # from [joint_val, node_flag] num_edge_types: 2 # robot joints, object-robot pred_horizon: ${pred_horizon} obs_horizon: ${obs_horizon} @@ -22,4 +22,4 @@ ckpt_path: ./weights/diffusion_graph_policy_${task.task_name}_${task.dataset_typ lr: 5e-5 batch_size: 32 noise_addition_std: 0.1 -use_normalization: True \ No newline at end of file +use_normalization: False \ No newline at end of file diff --git a/imitation/dataset/robomimic_graph_dataset.py b/imitation/dataset/robomimic_graph_dataset.py index 2901317..7254d43 100644 --- a/imitation/dataset/robomimic_graph_dataset.py +++ b/imitation/dataset/robomimic_graph_dataset.py @@ -26,11 +26,12 @@ def __init__(self, pred_horizon=1, obs_horizon=1, node_feature_dim = 2, # joint value and node type flag - control_mode="JOINT_VELOCITY", + control_mode="JOINT_POSITION", + obs_mode="JOINT_POSITION", base_link_shift=[0.0, 0.0, 0.0], base_link_rotation=[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0]]): self.control_mode : str = control_mode - self.obs_mode : str = "OSC_POSE" + self.obs_mode : str = obs_mode self.robots : List = robots self.num_robots : int = len(self.robots) self.node_feature_dim : int = 8 if control_mode == "OSC_POSE" else node_feature_dim @@ -65,8 +66,7 @@ def __init__(self, super().__init__(root=self._processed_dir, transform=None, pre_transform=None, pre_filter=None, log=True) self.stats = {} - self.stats["y"] = self.get_data_stats("y") - self.stats["x"] = self.get_data_stats("x") + self.stats = self.get_data_stats() self.constant_stats = { "y": torch.tensor([False, False, False, True, True, True, True, True, True]), # mask rotations for robot and object nodes @@ -134,13 +134,13 @@ def _get_node_actions(self, data, t_vals, modality="action"): elif self.control_mode == "JOINT_POSITION": for i in range(self.num_robots): node_feats.append(torch.cat([ - torch.tensor(data[f"robot{i}_joint_pos"][t_vals]), - torch.tensor(data[f"robot{i}_gripper_qpos"][t_vals])], dim=1).T.unsqueeze(2)) + torch.tensor(data["obs"][f"robot{i}_joint_pos"][t_vals]), + torch.tensor(data["obs"][f"robot{i}_gripper_qpos"][t_vals])], dim=1).T.unsqueeze(2)) elif self.control_mode == "JOINT_VELOCITY": for i in range(self.num_robots): node_feats.append(torch.cat([ - torch.tensor(data[f"robot{i}_joint_vel"][t_vals]), - torch.tensor(data[f"robot{i}_gripper_qvel"][t_vals])], dim=1).T.unsqueeze(2)) + torch.tensor(data["obs"][f"robot{i}_joint_vel"][t_vals]), + torch.tensor(data["obs"][f"robot{i}_gripper_qvel"][t_vals])], dim=1).T.unsqueeze(2)) node_feats = torch.cat(node_feats, dim=0) # [num_robots*num_joints, T, 1] return node_feats @@ -273,19 +273,21 @@ def get(self, idx): data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt')) return data - def get_data_stats(self, key): + def get_data_stats(self): ''' Returns min and max of data Used for normalizing data ''' - data = [] - for idx in range(self.len()): - data.append(torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))[key]) - data = torch.cat(data, dim=1) - return { - "min": torch.min(data, dim=1).values, - "max": torch.max(data, dim=1).values - } + stats = {} + for key in "action", "observation": + data = [] + for 
idx in range(self.len()): + data.append(torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))[key].x) + data = torch.cat(data, dim=1) + stats[key] = { + "min": torch.min(data, dim=1).values, + "max": torch.max(data, dim=1).values + } def normalize_data(self, data, stats_key, batch_size=1): # avoid division by zero by skipping normalization diff --git a/imitation/model/graph_diffusion.py b/imitation/model/graph_diffusion.py index 5dd9b94..fb9dc20 100644 --- a/imitation/model/graph_diffusion.py +++ b/imitation/model/graph_diffusion.py @@ -360,6 +360,7 @@ def forward(self, x, edge_index, edge_attr, x_coord, cond, timesteps, batch=None # make sure x and edge_attr are of type float, for the MLPs x = x.float().to(self.device).flatten(start_dim=1) edge_attr = edge_attr.float().to(self.device).unsqueeze(-1) # add channel dimension + edge_attr_cond = edge_attr_cond.float().to(self.device).unsqueeze(-1) # add channel dimension edge_index = edge_index.to(self.device) cond = cond.float().to(self.device) x_coord = x_coord.float().to(self.device) @@ -375,7 +376,8 @@ def forward(self, x, edge_index, edge_attr, x_coord, cond, timesteps, batch=None assert x.shape[0] == x_coord.shape[0], "x and x_coord must have the same length" edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=x.shape[0], fill_value=self.FILL_VALUE) - + edge_index_cond, edge_attr_cond = add_self_loops(edge_index_cond, edge_attr_cond, num_nodes=x.shape[0], fill_value=self.FILL_VALUE) + h_v = self.node_embedding(x) h_v = torch.cat([h_v, timesteps_embed], dim=-1) diff --git a/imitation/policy/graph_ddpm_policy.py b/imitation/policy/graph_ddpm_policy.py index d3230db..3ee710e 100644 --- a/imitation/policy/graph_ddpm_policy.py +++ b/imitation/policy/graph_ddpm_policy.py @@ -10,6 +10,7 @@ from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from diffusers.training_utils import EMAModel import wandb +from torch_geometric.data import Batch from imitation.policy.base_policy import BasePolicy @@ -100,6 +101,7 @@ def MOCK_get_graph_from_obs(self): # for testing purposes, remove before merge self.playback_count += 7 log.info(f"Playing back observation {self.playback_count}") return obs_cond, playback_graph + def get_action(self, obs_deque): B = 1 # action shape is (B, Ta, Da), observations (B, To, Do) # transform deques to numpy arrays @@ -282,7 +284,7 @@ def train(self, lr_scheduler = get_scheduler( name='cosine', optimizer=optimizer, - num_warmup_steps=500, + num_warmup_steps=5, num_training_steps=len(dataloader) * num_epochs ) @@ -298,19 +300,16 @@ def train(self, nobs = self.dataset.normalize_data(batch.y, stats_key='y', batch_size=batch.num_graphs).to(self.device) # normalize action naction = self.dataset.normalize_data(batch.x, stats_key='x', batch_size=batch.num_graphs).to(self.device) - naction = naction[:,:,:1] + else: + nobs_batch = Batch().from_data_list(batch["observation"]) + naction_batch = Batch().from_data_list(batch["action"]) + naction = naction_batch.x[:,:,:1] B = batch.num_graphs # observation as FiLM conditioning # (B, node, obs_horizon, obs_dim) - obs_cond = nobs[:,:,3:] # only 6D rotation - # filter only 2 last nodes of each graph by batch.ptr - obs_cond = torch.cat([obs_cond[batch.ptr[i+1]-2:batch.ptr[i+1]] for i in range(B)], dim=0) - batch_cond = torch.cat([batch.batch[batch.ptr[i+1]-2:batch.ptr[i+1]] for i in range(B)], dim=0) - x_coord_cond = torch.cat([batch.y[batch.ptr[i+1]-2:batch.ptr[i+1],-1,:3] for i in range(B)], dim=0) - edge_index_cond = torch.tensor([[[2*i,2*i+1],[2*i+1,2*i]] 
for i in range(batch_cond.shape[0] // 2)], dtype=torch.long, device=self.device) - edge_index_cond = edge_index_cond.flatten(end_dim=1).T - edge_attr_cond = torch.ones(edge_index_cond.shape[1], 1, device=self.device) + obs_cond = nobs_batch.x[:,:,3:] # only 6D rotation + # (B, obs_horizon * obs_dim) obs_cond = obs_cond.flatten(start_dim=1) @@ -324,13 +323,11 @@ def train(self, # (this is the forward diffusion process) # split naction into (B, N_nodes, pred_horizon, node_feature_dim), selecting the items from each batch.batch - naction = torch.cat([naction[batch.batch == i].unsqueeze(0) for i in batch.batch.unique()], dim=0) + naction = torch.cat([naction[naction_batch.batch == i].unsqueeze(0) for i in naction_batch.batch.unique()], dim=0) # add noise to first action instead of sampling from Gaussian noise = (1 - self.noise_addition_std) * naction[:,:,0,:].unsqueeze(2).repeat(1,1,naction.shape[2],1).float() + self.noise_addition_std * torch.randn(naction.shape, device=self.device, dtype=torch.float32) - noise = torch.randn(naction.shape, device=self.device, dtype=torch.float32) - noisy_actions = self.noise_scheduler.add_noise( naction, noise, timesteps) @@ -346,16 +343,16 @@ def train(self, # predict the noise residual noise_pred, x = self.noise_pred_net( noisy_actions, - batch.edge_index, - batch.edge_attr, - x_coord = batch.y[:,-1,:3], + naction_batch.edge_index, + naction_batch.edge_attr, + x_coord = naction_batch.pos[:,:3], cond=obs_cond, timesteps=timesteps, - batch=batch.batch, - batch_cond=batch_cond, - edge_index_cond=edge_index_cond, - edge_attr_cond=edge_attr_cond, - x_coord_cond=x_coord_cond) + batch=naction_batch.batch, + batch_cond=nobs_batch.batch, + edge_index_cond=nobs_batch.edge_index, + edge_attr_cond=nobs_batch.edge_attr, + x_coord_cond=nobs_batch.pos[:,:3]) # L2 loss loss = nn.functional.mse_loss(noise_pred, noise)
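
Note on the HeteroData dataset rework above: the following is a minimal, self-contained sketch of the kind of sample the patched process() emits (an "action" graph paired with an "observation" graph) and of per-modality min/max statistics in the spirit of the new get_data_stats(). All node counts, horizons and feature widths are illustrative assumptions, and make_sample() is an invented helper; the actual dataset stores the two graphs in a torch_geometric HeteroData container rather than the plain dict used here.

    # Sketch only: stand-in tensors with made-up shapes, assuming node features of
    # shape [num_nodes, horizon, feature_dim] and pos = xyz + 6D rotation (9 values).
    import torch
    from torch_geometric.data import Data

    def make_sample(num_nodes=16, pred_horizon=8, obs_horizon=2):
        edge_index = torch.tensor([[i, i + 1] for i in range(num_nodes - 1)],
                                  dtype=torch.long).t().contiguous()
        edge_attr = torch.zeros(edge_index.shape[1], dtype=torch.long)
        action = Data(x=torch.randn(num_nodes, pred_horizon, 1),   # one joint value per node over pred_horizon
                      edge_index=edge_index, edge_attr=edge_attr,
                      pos=torch.randn(num_nodes, 9))
        observation = Data(x=torch.randn(num_nodes, obs_horizon, 9),
                           edge_index=edge_index, edge_attr=edge_attr,
                           pos=torch.randn(num_nodes, 9))
        return {"action": action, "observation": observation}

    samples = [make_sample() for _ in range(4)]

    # Per-modality min/max taken across all timesteps of all samples, mirroring the
    # loop in get_data_stats(); the resulting dict is what self.stats is expected to hold.
    stats = {}
    for key in ("action", "observation"):
        data = torch.cat([s[key].x for s in samples], dim=1)
        stats[key] = {"min": data.min(dim=1).values, "max": data.max(dim=1).values}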
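
A companion sketch of the training step in the spirit of the patched train() loop: action graphs and observation graphs are collated separately with Batch.from_data_list, the regrouped action features are noised by the DDPM scheduler, and the observation batch supplies the *_cond arguments. Here toy_graph() and denoise() are invented stand-ins (denoise() replaces self.noise_pred_net), the scheduler settings are arbitrary, and every graph is assumed to have the same number of nodes so the per-graph regrouping can use torch.stack.

    import torch
    from torch_geometric.data import Batch, Data
    from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

    def toy_graph(horizon, feat_dim, num_nodes=16):
        edge_index = torch.tensor([[i, i + 1] for i in range(num_nodes - 1)],
                                  dtype=torch.long).t().contiguous()
        return Data(x=torch.randn(num_nodes, horizon, feat_dim),
                    edge_index=edge_index,
                    edge_attr=torch.zeros(edge_index.shape[1], dtype=torch.long),
                    pos=torch.randn(num_nodes, 9))

    def denoise(x, edge_index, edge_attr, x_coord, cond, timesteps, batch,
                batch_cond, edge_index_cond, edge_attr_cond, x_coord_cond):
        return torch.randn_like(x)  # placeholder for the EGNN noise-prediction network

    B = 4
    naction_batch = Batch.from_data_list([toy_graph(horizon=8, feat_dim=1) for _ in range(B)])
    nobs_batch = Batch.from_data_list([toy_graph(horizon=2, feat_dim=9) for _ in range(B)])

    naction = naction_batch.x[:, :, :1]                        # joint values only
    naction = torch.stack([naction[naction_batch.batch == i]   # regroup to [B, nodes, horizon, 1]
                           for i in range(B)], dim=0)
    obs_cond = nobs_batch.x[:, :, 3:].flatten(start_dim=1)     # conditioning features per observation node

    noise_scheduler = DDPMScheduler(num_train_timesteps=100)
    noise = torch.randn_like(naction)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (B,)).long()
    noisy_actions = noise_scheduler.add_noise(naction, noise, timesteps)

    noise_pred = denoise(noisy_actions,
                         naction_batch.edge_index, naction_batch.edge_attr,
                         x_coord=naction_batch.pos[:, :3],
                         cond=obs_cond, timesteps=timesteps,
                         batch=naction_batch.batch,
                         batch_cond=nobs_batch.batch,
                         edge_index_cond=nobs_batch.edge_index,
                         edge_attr_cond=nobs_batch.edge_attr,
                         x_coord_cond=nobs_batch.pos[:, :3])
    loss = torch.nn.functional.mse_loss(noise_pred, noise)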