diff --git a/imitation/config/policy/graph_ddpm_policy.yaml b/imitation/config/policy/graph_ddpm_policy.yaml
index 381a270..cec6c35 100644
--- a/imitation/config/policy/graph_ddpm_policy.yaml
+++ b/imitation/config/policy/graph_ddpm_policy.yaml
@@ -1,8 +1,6 @@
 _target_: imitation.policy.graph_ddpm_policy.GraphConditionalDDPMPolicy
-
 obs_dim: ${task.obs_dim}
-action_dim: ${task.action_dim}
-
+action_dim: 1 #${eval:'1 if ${task.control_mode} == "OSC_POSE" else ${task.action_dim}'}
 node_feature_dim: 1 # from [joint_val, node_flag]
 num_edge_types: 2 # robot joints, object-robot
 pred_horizon: ${pred_horizon}
@@ -24,4 +22,4 @@ ckpt_path: ./weights/diffusion_graph_policy_${task.task_name}_${task.dataset_typ
 lr: 5e-5
 batch_size: 32
 noise_addition_std: 0.1
-use_normalization: True
\ No newline at end of file
+use_normalization: False
\ No newline at end of file
diff --git a/imitation/config/task/lift_graph.yaml b/imitation/config/task/lift_graph.yaml
index 910690b..78ebac9 100644
--- a/imitation/config/task/lift_graph.yaml
+++ b/imitation/config/task/lift_graph.yaml
@@ -5,7 +5,7 @@ dataset_path: &dataset_path ./data/lift/${task.dataset_type}/low_dim_v141.hdf5
 
 max_steps: ${max_steps}
 
-control_mode: "JOINT_POSITION"
+control_mode: "OSC_POSE"
 
 obs_dim: 9
 action_dim: 9
diff --git a/imitation/config/task/square_graph.yaml b/imitation/config/task/square_graph.yaml
index bfc1ea5..4a0a308 100644
--- a/imitation/config/task/square_graph.yaml
+++ b/imitation/config/task/square_graph.yaml
@@ -57,4 +57,4 @@ dataset:
   object_state_sizes: *object_state_sizes
   object_state_keys: *object_state_keys
   control_mode: ${task.control_mode}
-  base_link_shift: [[-0.56, 0, 0.912]]
\ No newline at end of file
+  base_link_shift: [[-0.56, 0, 0.912]]
diff --git a/imitation/config/task/transport_graph.yaml b/imitation/config/task/transport_graph.yaml
index 5609463..dff447d 100644
--- a/imitation/config/task/transport_graph.yaml
+++ b/imitation/config/task/transport_graph.yaml
@@ -47,6 +47,13 @@ base_link_rotation:
   - [0.707107, 0, 0, 0.707107]
   - [0.707107, 0, 0, -0.707107]
 
+base_link_shift:
+  - [0.0, -0.81, 0.912]
+  - [0.0, 0.81, 0.912]
+base_link_rotation:
+  - [0.707107, 0, 0, 0.707107]
+  - [0.707107, 0, 0, -0.707107]
+
 env_runner:
   _target_: imitation.env_runner.robomimic_lowdim_runner.RobomimicEnvRunner
   output_dir: ${output_dir}
@@ -79,6 +86,7 @@ dataset:
   obs_horizon: ${obs_horizon}
   object_state_sizes: *object_state_sizes
   object_state_keys: *object_state_keys
+  robots: *robots
   control_mode: ${task.control_mode}
   base_link_shift: ${task.base_link_shift}
   base_link_rotation: ${task.base_link_rotation}
\ No newline at end of file
diff --git a/imitation/dataset/robomimic_graph_dataset.py b/imitation/dataset/robomimic_graph_dataset.py
index cd19bb1..7254d43 100644
--- a/imitation/dataset/robomimic_graph_dataset.py
+++ b/imitation/dataset/robomimic_graph_dataset.py
@@ -1,6 +1,6 @@
 from typing import Callable, Optional
 
-from torch_geometric.data import Dataset, Data, InMemoryDataset
+from torch_geometric.data import Dataset, Data, InMemoryDataset, HeteroData
 import logging
 import h5py
 import os.path as osp
@@ -26,13 +26,15 @@ def __init__(self,
                  pred_horizon=1,
                  obs_horizon=1,
                  node_feature_dim = 2, # joint value and node type flag
-                 control_mode="JOINT_VELOCITY",
+                 control_mode="JOINT_POSITION",
+                 obs_mode="JOINT_POSITION",
                  base_link_shift=[0.0, 0.0, 0.0],
                  base_link_rotation=[[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0]]):
         self.control_mode : str = control_mode
-        self.node_feature_dim : int = node_feature_dim
+        self.obs_mode : str = obs_mode
         self.robots : List = robots
         self.num_robots : int = len(self.robots)
+        self.node_feature_dim : int = 8 if control_mode == "OSC_POSE" else node_feature_dim
         self.pred_horizon : int = pred_horizon
         self.obs_horizon : int = obs_horizon
         self.object_state_sizes : Dict = object_state_sizes # can be taken from https://github.com/ARISE-Initiative/robosuite/tree/master/robosuite/environments/manipulation
@@ -64,12 +66,11 @@ def __init__(self,
         super().__init__(root=self._processed_dir, transform=None, pre_transform=None, pre_filter=None, log=True)
 
         self.stats = {}
-        self.stats["y"] = self.get_data_stats("y")
-        self.stats["x"] = self.get_data_stats("x")
+        self.stats = self.get_data_stats()
 
         self.constant_stats = {
             "y": torch.tensor([False, False, False, True, True, True, True, True, True]), # mask rotations for robot and object nodes
-            "x": torch.tensor([False, True]) # node type flag is constant
+            "x": torch.tensor([False, False, False, True, True, True, True, True, True, True]) # node type flag is constant
         }
@@ -81,14 +82,6 @@ def processed_file_names(self):
         names = [f"data_{i}.pt" for i in range(self.len())]
         return names
 
-    # @lru_cache(maxsize=None)
-    def _get_object_feats(self, num_objects, node_feature_dim, OBJECT_NODE_TYPE, T): # no associated joint values
-        # create tensor of same dimension return super()._get_node_feats(data, t) as node_feats
-        obj_state_tensor = torch.zeros((num_objects, T, node_feature_dim))
-        # add dimension for NODE_TYPE, which is 0 for robot and 1 for objects
-        obj_state_tensor[:,:,-1] = OBJECT_NODE_TYPE
-        return obj_state_tensor
-
     def _get_object_pos(self, data, t):
         obj_state_tensor = torch.zeros((self.num_objects, 9)) # 3 for position, 6 for 6D rotation
@@ -105,25 +98,30 @@ def _get_object_pos(self, data, t):
         return obj_state_tensor
 
-    def _get_node_pos(self, data, t):
+    def _get_node_pos(self, data, t, modality="action"):
         node_pos = []
         for i in range(self.num_robots):
-            node_pos_robot = calculate_panda_joints_positions([*data[f"robot{i}_joint_pos"][t], *data[f"robot{i}_gripper_qpos"][t]])
-            # rotate robot nodes
-            rotation_matrix = R.from_quat(self.BASE_LINK_ROTATION[i])
-            node_pos_robot[:,:3] = torch.matmul(node_pos_robot[:,:3], torch.tensor(rotation_matrix.as_matrix()))
-            node_pos_robot[:,3:] = torch.tensor((R.from_quat(node_pos_robot[:,3:].detach().numpy()) * rotation_matrix).as_quat())
-            # add base link shift
-            node_pos_robot[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i])
+            if self.control_mode == "OSC_POSE":
+                node_pos_robot = torch.cat([torch.tensor(data[f"robot{i}_eef_pos"][t]), torch.tensor(data[f"robot{i}_eef_quat"][t])], dim=0).unsqueeze(0)
+            else:
+                node_pos_robot = calculate_panda_joints_positions([*data[f"robot{i}_joint_pos"][t], *data[f"robot{i}_gripper_qpos"][t]])
+                # rotate robot nodes
+                rotation_matrix = R.from_quat(self.BASE_LINK_ROTATION[i])
+                node_pos_robot[:,:3] = torch.matmul(node_pos_robot[:,:3], torch.tensor(rotation_matrix.as_matrix()))
+                node_pos_robot[:,3:] = torch.tensor((R.from_quat(node_pos_robot[:,3:].detach().numpy()) * rotation_matrix).as_quat())
+                # add base link shift
+                node_pos_robot[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i])
             node_pos.append(node_pos_robot)
         node_pos = torch.cat(node_pos, dim=0)
         # use rotation transformer to convert quaternion to 6d rotation
         node_pos = torch.cat([node_pos[:,:3], self.rotation_transformer.forward(node_pos[:,3:])], dim=1)
-        obj_pos_tensor = self._get_object_pos(data, t)
-        node_pos = torch.cat((node_pos, obj_pos_tensor), dim=0)
"observation": + obj_pos_tensor = self._get_object_pos(data, t) + node_pos = torch.cat((node_pos, obj_pos_tensor), dim=0) return node_pos - def _get_node_feats(self, data, t_vals): + def _get_node_actions(self, data, t_vals, modality="action"): ''' Calculate node features for time steps t_vals t_vals: list of time steps @@ -132,47 +130,45 @@ def _get_node_feats(self, data, t_vals): node_feats = [] if self.control_mode == "OSC_POSE": for i in range(self.num_robots): - node_feats.append(torch.cat([torch.tensor(data["robot0_eef_pos"][t_vals]), torch.tensor(data["robot0_eef_quat"][t_vals])], dim=0)) + node_feats.append(torch.cat([torch.tensor(data["actions"][t_vals])], dim=1).unsqueeze(0)) elif self.control_mode == "JOINT_POSITION": for i in range(self.num_robots): node_feats.append(torch.cat([ - torch.tensor(data[f"robot{i}_joint_pos"][t_vals]), - torch.tensor(data[f"robot{i}_gripper_qpos"][t_vals])], dim=1).T.unsqueeze(2)) + torch.tensor(data["obs"][f"robot{i}_joint_pos"][t_vals]), + torch.tensor(data["obs"][f"robot{i}_gripper_qpos"][t_vals])], dim=1).T.unsqueeze(2)) elif self.control_mode == "JOINT_VELOCITY": for i in range(self.num_robots): node_feats.append(torch.cat([ - torch.tensor(data[f"robot{i}_joint_vel"][t_vals]), - torch.tensor(data[f"robot{i}_gripper_qvel"][t_vals])], dim=1).T.unsqueeze(2)) + torch.tensor(data["obs"][f"robot{i}_joint_vel"][t_vals]), + torch.tensor(data["obs"][f"robot{i}_gripper_qvel"][t_vals])], dim=1).T.unsqueeze(2)) node_feats = torch.cat(node_feats, dim=0) # [num_robots*num_joints, T, 1] - # add dimension for NODE_TYPE, which is 0 for robot and 1 for objects - node_feats = torch.cat((node_feats, self.ROBOT_NODE_TYPE*torch.ones((node_feats.shape[0],node_feats.shape[1],1))), dim=2) - - obj_state_tensor = self._get_object_feats(self.num_objects, self.node_feature_dim, self.OBJECT_NODE_TYPE, T) - - node_feats = torch.cat((node_feats, obj_state_tensor), dim=0) return node_feats - def _get_node_feats_horizon(self, data, idx, horizon): + def _get_node_actions_horizon(self, data, idx, horizon): ''' Calculate node features for self.obs_horizon time steps ''' node_feats = [] # calculate node features for timesteps idx to idx + horizon t_vals = list(range(idx, idx + horizon)) - node_feats = self._get_node_feats(data, t_vals) + node_feats = self._get_node_actions(data, t_vals) return node_feats - @lru_cache(maxsize=None) - def _get_edge_attrs(self, edge_index): + # @lru_cache(maxsize=None) + def _get_edge_attrs(self, edge_index, modality="action"): ''' Attribute edge types to edges - self.ROBOT_LINK_EDGE for edges between robot nodes - self.OBJECT_ROBOT_EDGE for edges between robot and object nodes ''' edge_attrs = [] - num_nodes = torch.max(edge_index) - for edge in edge_index.t(): + + if modality == "action": + return torch.ones(edge_index.shape[1], dtype=torch.long) * self.ROBOT_LINK_EDGE + + num_nodes = torch.max(edge_index).item() + for edge in edge_index.T: # num nodes - self.num_objects is the index of the last robot node if edge[0] <= num_nodes - self.num_objects and edge[1] <= num_nodes - self.num_objects: edge_attrs.append(self.ROBOT_LINK_EDGE) @@ -181,37 +177,47 @@ def _get_edge_attrs(self, edge_index): edge_attrs.append(self.OBJECT_ROBOT_EDGE) return torch.tensor(edge_attrs, dtype=torch.long) - @lru_cache(maxsize=None) - def _get_edge_index(self, num_nodes): + # @lru_cache(maxsize=None) + def _get_edge_index(self, num_nodes, modality="action"): ''' Returns edge index for graph. 
         - all robot nodes are connected to the previous robot node
         - all object nodes are connected to the last robot node (end-effector)
         '''
-        assert len(self.eef_idx) == self.num_robots + 1
         edge_index = []
-        if len(self.eef_idx) == 3: # 2 robots
-            edge_index = [[self.eef_idx[0]+ 1, self.eef_idx[1] + 1]] # robot0 base link to robot1 base link
-        for robot in range(self.num_robots):
-            # Connectivity of all robot nodes to the previous robot node
-            edge_index += [[idx, idx+1] for idx in range(self.eef_idx[robot]+ 1, self.eef_idx[robot+1])]
-        # Connectivity of all other nodes to all robot nodes
-        edge_index += [[node_idx, idx] for idx in range(self.eef_idx[-1] + 1, num_nodes) for node_idx in range(self.eef_idx[self.num_robots] + 1)]
-        # edge_index.append(torch.tensor([node_idx, idx]) for node_idx in range(self.eef_idx[self.num_robots] + 1))
+        graph_type = self.control_mode if modality == "action" else self.obs_mode
+        eef_idx = self.eef_idx
+        if graph_type == "OSC_POSE": # 1 node per robot
+            eef_idx = list(range(-1, self.num_robots))
+            edge_index += [[0, 0]]
+            edge_index += [[eef_idx[robot_node], eef_idx[robot_node+1]] for robot_node in range(1, len(eef_idx)-1)]
+        else: # JOINT_POSITION or JOINT_VELOCITY
+            assert len(self.eef_idx) == self.num_robots + 1
+
+            if len(self.eef_idx) == 3: # 2 robots
+                edge_index = [[self.eef_idx[0]+ 1, self.eef_idx[1] + 1]] # robot0 base link to robot1 base link
+            for robot in range(self.num_robots):
+                # Connectivity of all robot nodes to the previous robot node
+                edge_index += [[idx, idx+1] for idx in range(self.eef_idx[robot]+ 1, self.eef_idx[robot+1])]
+
+        if modality == "observation":
+            # Connectivity of all objects to all robot nodes
+            edge_index += [[node_idx, idx] for idx in range(eef_idx[-1] + 1, num_nodes) for node_idx in range(eef_idx[self.num_robots] + 1)]
+
         edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
         return edge_index
 
-    def _get_y_horizon(self, data, idx, horizon):
+    def _get_obs_horizon(self, data, idx, horizon, modality="observation"):
         '''
-        Get y (observation) for time step t. Should contain only task-space joint positions.
+        Get observation node features for time step t. Should contain only task-space joint positions.
         '''
-        y = []
+        obs = []
         for t in range(idx, idx - horizon,-1):
             if t < 0:
-                y.append(self._get_node_pos(data, 0)) # use fixed first observation for beginning of episode
+                obs.append(self._get_node_pos(data, 0, modality="observation")) # use fixed first observation for beginning of episode
             else:
-                y.append(self._get_node_pos(data, t))
-        return torch.stack(y, dim=1)
+                obs.append(self._get_node_pos(data, t, modality="observation"))
+        return torch.stack(obs, dim=1)
 
 
     def process(self):
@@ -222,22 +228,37 @@ def process(self):
             for idx in range(episode_length - self.pred_horizon):
-                data_raw = self.dataset_root["data"][key]["obs"]
-                node_feats = self._get_node_feats_horizon(data_raw, idx, self.pred_horizon)
-                edge_index = self._get_edge_index(node_feats.shape[0])
-                edge_attrs = self._get_edge_attrs(edge_index)
-                y = self._get_y_horizon(data_raw, idx, self.obs_horizon)
-                pos = self._get_node_pos(data_raw, idx + self.pred_horizon)
+                data_raw = self.dataset_root["data"][key]
+                action_feats = self._get_node_actions_horizon(data_raw, idx, self.pred_horizon)
+                action_edge_index = self._get_edge_index(action_feats.shape[0], modality="action")
+                action_edge_attrs = self._get_edge_attrs(action_edge_index, modality="action")
+                action_pos = self._get_node_pos(data_raw["obs"], idx + self.pred_horizon, modality="action")
+
+                obs_feats = self._get_obs_horizon(data_raw["obs"], idx, self.obs_horizon)
+                obs_pos = self._get_node_pos(data_raw["obs"], idx, modality="observation")
+                obs_edge_index = self._get_edge_index(obs_feats.shape[0], modality="observation")
+                obs_edge_attrs = self._get_edge_attrs(obs_edge_index, modality="observation")
 
-                data = Data(
-                    x=node_feats,
-                    edge_index=edge_index,
-                    edge_attr=edge_attrs,
-                    y=y,
+                action_graph = Data(
+                    x=action_feats,
+                    edge_index=action_edge_index,
+                    edge_attr=action_edge_attrs,
+                    y=action_pos,
                     time=torch.tensor([idx], dtype=torch.long)/ episode_length,
-                    pos=pos
+                    pos=action_pos
                 )
+                obs_graph = Data(
+                    x=obs_feats,
+                    edge_index=obs_edge_index,
+                    edge_attr=obs_edge_attrs,
+                    time=torch.tensor([idx], dtype=torch.long)/ episode_length,
+                    pos=obs_pos
+                )
+                data = HeteroData()
+                data["action"] = action_graph
+                data["observation"] = obs_graph
+
                 torch.save(data, osp.join(self.processed_dir, f'data_{idx_global}.pt'))
                 idx_global += 1
@@ -252,19 +273,21 @@ def get(self, idx):
         data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
         return data
 
-    def get_data_stats(self, key):
+    def get_data_stats(self):
         '''
         Returns min and max of data
         Used for normalizing data
         '''
-        data = []
-        for idx in range(self.len()):
-            data.append(torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))[key])
-        data = torch.cat(data, dim=1)
-        return {
-            "min": torch.min(data, dim=1).values,
-            "max": torch.max(data, dim=1).values
-        }
+        stats = {}
+        for key in "action", "observation":
+            data = []
+            for idx in range(self.len()):
+                data.append(torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))[key].x)
+            data = torch.cat(data, dim=1)
+            stats[key] = {
+                "min": torch.min(data, dim=1).values,
+                "max": torch.max(data, dim=1).values
+            }
+        return stats
 
     def normalize_data(self, data, stats_key, batch_size=1):
         # avoid division by zero by skipping normalization
diff --git a/imitation/env/robomimic_graph_wrapper.py b/imitation/env/robomimic_graph_wrapper.py
index f039939..5d5f998 100644
--- a/imitation/env/robomimic_graph_wrapper.py
+++ b/imitation/env/robomimic_graph_wrapper.py
@@ -61,7 +61,7 @@ def __init__(self,
         '''
         self.object_state_sizes = object_state_sizes
         self.object_state_keys = object_state_keys
-        self.node_feature_dim = node_feature_dim
+        self.node_feature_dim = 8 if control_mode == "OSC_POSE" else node_feature_dim
         self.control_mode = control_mode
         controller_config = load_controller_config(default_controller=self.control_mode)
         # override default controller config with user-specified values
@@ -92,7 +92,7 @@ def __init__(self,
         self.observation_space = self.env.observation_space
         self.num_objects = len(object_state_keys)
-        self.NUM_GRAPH_NODES = self.num_robots*9 + self.num_objects # TODO add multi-robot support
+        self.NUM_GRAPH_NODES = 1 + self.num_objects if self.control_mode == "OSC_POSE" else self.num_robots*9 + self.num_objects
         self.BASE_LINK_SHIFT = base_link_shift
         self.BASE_LINK_ROTATION = base_link_rotation
         self.ROBOT_NODE_TYPE = 1
@@ -160,16 +160,19 @@ def _get_object_pos(self, data):
     def _get_node_pos(self, data):
         node_pos = []
         for i in range(self.num_robots):
-            node_pos_robot = calculate_panda_joints_positions([*data[f"robot{i}_joint_pos"], *data[f"robot{i}_gripper_qpos"]])
-            rotation_matrix = R.from_quat(self.BASE_LINK_ROTATION[i])
-            node_pos_robot[:,:3] = torch.matmul(node_pos_robot[:,:3], torch.tensor(rotation_matrix.as_matrix()))
-            node_pos_robot[:,3:] = torch.tensor((R.from_quat(node_pos_robot[:,3:].detach().numpy()) * rotation_matrix).as_quat())
-            # add base link shift
-            node_pos_robot[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i])
+            if self.control_mode == "OSC_POSE":
+                node_pos_robot = torch.cat([torch.tensor(data[f"robot{i}_eef_pos"]), torch.tensor(data[f"robot{i}_eef_6d_rot"])], dim=0).unsqueeze(0)
+            else:
+                node_pos_robot = calculate_panda_joints_positions([*data[f"robot{i}_joint_pos"], *data[f"robot{i}_gripper_qpos"]])
+                rotation_matrix = R.from_quat(self.BASE_LINK_ROTATION[i])
+                node_pos_robot[:,:3] = torch.matmul(node_pos_robot[:,:3], torch.tensor(rotation_matrix.as_matrix()))
+                node_pos_robot[:,3:] = torch.tensor((R.from_quat(node_pos_robot[:,3:].detach().numpy()) * rotation_matrix).as_quat())
+                # add base link shift
+                node_pos_robot[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i])
+                # use rotation transformer to convert quaternion to 6d rotation
+                node_pos_robot = torch.cat([node_pos_robot[:,:3], self.rotation_transformer.forward(node_pos_robot[:,3:])], dim=1)
             node_pos.append(node_pos_robot)
         node_pos = torch.cat(node_pos, dim=0)
-        # use rotation transformer to convert quaternion to 6d rotation
-        node_pos = torch.cat([node_pos[:,:3], self.rotation_transformer.forward(node_pos[:,3:])], dim=1)
         obj_pos_tensor = self._get_object_pos(data)
         node_pos = torch.cat((node_pos, obj_pos_tensor), dim=0)
         return node_pos
@@ -181,9 +184,10 @@ def _get_node_feats(self, data):
         Returns node features from data
         '''
         node_feats = []
-        for i in range(self.num_robots):
+        for i in range(len(self.robots)):
             if self.control_mode == "OSC_POSE":
-                node_feats.append(torch.cat([torch.tensor(data[f"robot{i}_eef_pos"]), torch.tensor(data[f"robot{i}_eef_quat"])], dim=0).reshape(1, -1)) # add dimension
+                # node_feats.append(torch.cat([torch.tensor(data[f"robot{i}_eef_pos"]), torch.tensor(data[f"robot{i}_eef_6d_rot"])], dim=0).unsqueeze(0))
+                node_feats.append(torch.zeros((1,7)))
             elif self.control_mode == "JOINT_VELOCITY":
                 node_feats.append(torch.tensor([*data[f"robot{i}_joint_vel"], *data[f"robot{i}_gripper_qvel"]]).reshape(1,-1).T)
             elif self.control_mode == "JOINT_POSITION":
@@ -260,7 +264,7 @@ def _robosuite_obs_to_robomimic_graph(self, obs):
                 f"robot{i}_joint_pos": robot_joint_pos,
                 f"robot{i}_joint_vel": robot_joint_vel,
                 f"robot{i}_eef_pos": eef_pose,
-                f"robot{i}_eef_quat": eef_6d,
f"robot{i}_eef_6d_rot": eef_6d, f"robot{i}_gripper_qpos": gripper_pose, f"robot{i}_gripper_qvel": gripper_vel }) @@ -297,7 +301,11 @@ def reset(self): def step(self, action): final_action = [] - for i in range(self.num_robots): + if self.control_mode == "OSC_POSE": + obs, reward, done, _, info = self.env.step(action) + return self._robosuite_obs_to_robomimic_graph(obs), reward, done, info + + for i in range(len(self.robots)): ''' Robosuite's action space is composed of 7 joint velocities and 1 gripper velocity, while in the robomimic datasets, it's composed of 7 joint velocities and 2 gripper velocities (for each "finger"). diff --git a/imitation/model/graph_diffusion.py b/imitation/model/graph_diffusion.py index cf4b076..fb9dc20 100644 --- a/imitation/model/graph_diffusion.py +++ b/imitation/model/graph_diffusion.py @@ -356,15 +356,18 @@ def positionalencoding(self, lengths): return pes - def forward(self, x, edge_index, edge_attr, x_coord, cond, timesteps, batch=None): + def forward(self, x, edge_index, edge_attr, x_coord, cond, timesteps, batch=None, batch_cond=None, edge_index_cond=None, edge_attr_cond=None, x_coord_cond=None): # make sure x and edge_attr are of type float, for the MLPs x = x.float().to(self.device).flatten(start_dim=1) edge_attr = edge_attr.float().to(self.device).unsqueeze(-1) # add channel dimension + edge_attr_cond = edge_attr_cond.float().to(self.device).unsqueeze(-1) # add channel dimension edge_index = edge_index.to(self.device) cond = cond.float().to(self.device) x_coord = x_coord.float().to(self.device) + x_coord_cond = x_coord_cond.float().to(self.device) timesteps = timesteps.to(self.device) batch = batch.long().to(self.device) + batch_cond = batch_cond.long().to(self.device) batch_size = batch[-1] + 1 timesteps_embed = self.diffusion_step_encoder(self.pe[timesteps]) @@ -373,7 +376,8 @@ def forward(self, x, edge_index, edge_attr, x_coord, cond, timesteps, batch=None assert x.shape[0] == x_coord.shape[0], "x and x_coord must have the same length" edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=x.shape[0], fill_value=self.FILL_VALUE) - + edge_index_cond, edge_attr_cond = add_self_loops(edge_index_cond, edge_attr_cond, num_nodes=x.shape[0], fill_value=self.FILL_VALUE) + h_v = self.node_embedding(x) h_v = torch.cat([h_v, timesteps_embed], dim=-1) @@ -381,7 +385,7 @@ def forward(self, x, edge_index, edge_attr, x_coord, cond, timesteps, batch=None h_e = self.edge_embedding(edge_attr.reshape(-1, 1)) # FiLM generator - embed = self.cond_encoder(cond, edge_index, x_coord, edge_attr, batch=batch) + embed = self.cond_encoder(cond, edge_index_cond, x_coord_cond, edge_attr_cond, batch=batch_cond) embed = embed.reshape(self.num_layers, batch_size, 2, (self.hidden_dim + self.diffusion_step_embed_dim)) scales = embed[:,:,0,...] biases = embed[:,:,1,...] 
diff --git a/imitation/policy/egnn_policy.py b/imitation/policy/egnn_policy.py
index 8726d27..bce2baa 100644
--- a/imitation/policy/egnn_policy.py
+++ b/imitation/policy/egnn_policy.py
@@ -73,7 +73,7 @@ def get_action(self, obs_deque):
                 x=y[:,-1,:3].to(self.device).float(),
             )
         pred = pred.reshape(-1, self.pred_horizon, self.node_feature_dim)
-        return pred[:9,:,0].T.detach().cpu().numpy() # return joint values only
+        return pred[0,:,:7].detach().cpu().numpy() # return joint values only
 
     def validate(self, dataset, model_path):
         '''
@@ -118,7 +118,7 @@ def train(self, dataset, num_epochs, model_path, seed=0):
         )
 
         loss_fn = nn.MSELoss()
-        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
+        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=1e-6)
         # LR scheduler with warmup
         lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.8)
         # visualize data in batch
@@ -140,7 +140,7 @@ def train(self, dataset, num_epochs, model_path, seed=0):
                     x=nbatch.y[:,-1,:3].to(self.device).float(),
                 )
                 pred = pred.reshape(-1, self.pred_horizon, self.node_feature_dim)
-                loss = loss_fn(pred, action)
+                loss = loss_fn(pred[:,:,0], action[:,:,0])
                 # loss_x = loss_fn(x, nbatch.pos[:,:3].to(self.device).float())
                 loss.backward()
                 optimizer.step()
@@ -149,11 +149,11 @@ def train(self, dataset, num_epochs, model_path, seed=0):
                 pbar.set_postfix({"loss": loss.item()})
                 wandb.log({"loss": loss.item()})
-
+            # save model
+            torch.save(self.model.state_dict(), model_path)
             self.global_epoch += 1
             wandb.log({"epoch": self.global_epoch, "loss": loss.item()})
 
-            # save model
             torch.save(self.model.state_dict(), model_path)
             pbar.set_description(f"Epoch: {self.global_epoch}, Loss: {loss.item()}")
diff --git a/imitation/policy/graph_ddpm_policy.py b/imitation/policy/graph_ddpm_policy.py
index 5dae334..3ee710e 100644
--- a/imitation/policy/graph_ddpm_policy.py
+++ b/imitation/policy/graph_ddpm_policy.py
@@ -10,6 +10,7 @@
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from diffusers.training_utils import EMAModel
 import wandb
+from torch_geometric.data import Batch
 
 from imitation.policy.base_policy import BasePolicy
@@ -100,6 +101,7 @@ def MOCK_get_graph_from_obs(self): # for testing purposes, remove before merge
         self.playback_count += 7
         log.info(f"Playing back observation {self.playback_count}")
         return obs_cond, playback_graph
+
     def get_action(self, obs_deque):
         B = 1 # action shape is (B, Ta, Da), observations (B, To, Do)
         # transform deques to numpy arrays
@@ -133,9 +135,13 @@ def get_action(self, obs_deque):
                     edge_index = G_t.edge_index,
                     edge_attr = G_t.edge_attr,
                     x_coord = nobs[:,-1,:3],
-                    cond = nobs[:,:,3:],
+                    cond = nobs[-2:,:,3:], # only end-effector and object
                    timesteps = torch.tensor([k], dtype=torch.long, device=self.device),
-                    batch = torch.zeros(naction.shape[0], dtype=torch.long, device=self.device)
+                    batch = torch.zeros(naction.shape[0], dtype=torch.long, device=self.device),
+                    batch_cond = torch.zeros(nobs[-2:,:,3:].shape[0], dtype=torch.long, device=self.device),
+                    edge_index_cond = torch.tensor([[0,1],[1,0]], dtype=torch.long, device=self.device),
+                    edge_attr_cond = torch.ones(2, 1, device=self.device),
+                    x_coord_cond = nobs[-2:,-1,3:]
                 )
 
                 # inverse diffusion step (remove noise)
@@ -182,7 +188,14 @@ def validate(self, dataset=None, model_path="last.pt"):
 
             # observation as FiLM conditioning
             # (B, node, obs_horizon, obs_dim)
-            obs_cond = nobs[:,:,3:]
+            obs_cond = nobs[:,:,3:] # only 6D rotation
+            # filter only 2 last nodes of each graph by batch.ptr
+            obs_cond = torch.cat([obs_cond[batch.ptr[i+1]-2:batch.ptr[i+1]] for i in range(B)], dim=0)
+            batch_cond = torch.cat([batch.batch[batch.ptr[i+1]-2:batch.ptr[i+1]] for i in range(B)], dim=0)
+            x_coord_cond = torch.cat([batch.y[batch.ptr[i+1]-2:batch.ptr[i+1],-1,:3] for i in range(B)], dim=0)
+            edge_index_cond = torch.tensor([[[2*i,2*i+1],[2*i+1,2*i]] for i in range(batch_cond.shape[0] // 2)], dtype=torch.long, device=self.device)
+            edge_index_cond = edge_index_cond.flatten(end_dim=1).T
+            edge_attr_cond = torch.ones(edge_index_cond.shape[1], 1, device=self.device)
             # (B, obs_horizon * obs_dim)
             obs_cond = obs_cond.flatten(start_dim=1)
@@ -215,14 +228,18 @@
             noise = noise.flatten(end_dim=1)
 
             # predict the noise residual
-            noise_pred, x = self.ema_noise_pred_net(
+            noise_pred, x = self.noise_pred_net(
                 noisy_actions,
                 batch.edge_index,
                 batch.edge_attr,
                 x_coord = batch.y[:,-1,:3],
                 cond=obs_cond,
                 timesteps=timesteps,
-                batch=batch.batch)
+                batch=batch.batch,
+                batch_cond=batch_cond,
+                edge_index_cond=edge_index_cond,
+                edge_attr_cond=edge_attr_cond,
+                x_coord_cond=x_coord_cond)
 
             # L2 loss
             loss = nn.functional.mse_loss(noise_pred, noise)
@@ -267,7 +284,7 @@ def train(self,
         lr_scheduler = get_scheduler(
             name='cosine',
             optimizer=optimizer,
-            num_warmup_steps=500,
+            num_warmup_steps=5,
             num_training_steps=len(dataloader) * num_epochs
         )
@@ -283,12 +300,16 @@ def train(self,
                     nobs = self.dataset.normalize_data(batch.y, stats_key='y', batch_size=batch.num_graphs).to(self.device)
                     # normalize action
                    naction = self.dataset.normalize_data(batch.x, stats_key='x', batch_size=batch.num_graphs).to(self.device)
-                    naction = naction[:,:,:1]
+                else:
+                    nobs_batch = Batch().from_data_list(batch["observation"])
+                    naction_batch = Batch().from_data_list(batch["action"])
+                    naction = naction_batch.x[:,:,:1]
 
                 B = batch.num_graphs
                 # observation as FiLM conditioning
                 # (B, node, obs_horizon, obs_dim)
-                obs_cond = nobs[:,:,3:] # only 6D rotation
+                obs_cond = nobs_batch.x[:,:,3:] # only 6D rotation
+
                 # (B, obs_horizon * obs_dim)
                 obs_cond = obs_cond.flatten(start_dim=1)
@@ -302,13 +323,11 @@ def train(self,
                 # (this is the forward diffusion process)
 
                 # split naction into (B, N_nodes, pred_horizon, node_feature_dim), selecting the items from each batch.batch
-                naction = torch.cat([naction[batch.batch == i].unsqueeze(0) for i in batch.batch.unique()], dim=0)
+                naction = torch.cat([naction[naction_batch.batch == i].unsqueeze(0) for i in naction_batch.batch.unique()], dim=0)
 
                 # add noise to first action instead of sampling from Gaussian
                 noise = (1 - self.noise_addition_std) * naction[:,:,0,:].unsqueeze(2).repeat(1,1,naction.shape[2],1).float() + self.noise_addition_std * torch.randn(naction.shape, device=self.device, dtype=torch.float32)
 
-                noise = torch.randn(naction.shape, device=self.device, dtype=torch.float32)
-
                 noisy_actions = self.noise_scheduler.add_noise(
                     naction, noise, timesteps)
@@ -324,12 +343,16 @@ def train(self,
                 # predict the noise residual
                 noise_pred, x = self.noise_pred_net(
                     noisy_actions,
-                    batch.edge_index,
-                    batch.edge_attr,
-                    x_coord = batch.y[:,-1,:3],
+                    naction_batch.edge_index,
+                    naction_batch.edge_attr,
+                    x_coord = naction_batch.pos[:,:3],
                     cond=obs_cond,
                     timesteps=timesteps,
-                    batch=batch.batch)
+                    batch=naction_batch.batch,
+                    batch_cond=nobs_batch.batch,
+                    edge_index_cond=nobs_batch.edge_index,
+                    edge_attr_cond=nobs_batch.edge_attr,
+                    x_coord_cond=nobs_batch.pos[:,:3])
 
                 # L2 loss
                 loss = nn.functional.mse_loss(noise_pred, noise)
diff --git a/imitation/utils/generic.py b/imitation/utils/generic.py
index 365222c..52382f1 100644
--- a/imitation/utils/generic.py
+++ b/imitation/utils/generic.py
@@ -3,7 +3,7 @@
 from scipy.spatial.transform import Rotation as R
 
-from torch_robotics.torch_kinematics_tree.models.robots import DifferentiableFrankaPanda
+from torch_kinematics_tree.models.robots import DifferentiableFrankaPanda
 
 
 def to_numpy(x):
     return x.detach().cpu().numpy()