Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

109 try changing observation and/or actions only to eef on graph policy #113

Open
wants to merge 10 commits into
base: development
Choose a base branch
from
6 changes: 2 additions & 4 deletions imitation/config/policy/graph_ddpm_policy.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
_target_: imitation.policy.graph_ddpm_policy.GraphConditionalDDPMPolicy

obs_dim: ${task.obs_dim}
action_dim: ${task.action_dim}

action_dim: 1 #${eval:'1 if ${task.control_mode} == "OSC_POSE" else ${task.action_dim}'}
node_feature_dim: 1 # from [joint_val, node_flag]
num_edge_types: 2 # robot joints, object-robot
pred_horizon: ${pred_horizon}
Expand All @@ -24,4 +22,4 @@ ckpt_path: ./weights/diffusion_graph_policy_${task.task_name}_${task.dataset_typ
lr: 5e-5
batch_size: 32
noise_addition_std: 0.1
use_normalization: True
use_normalization: False
2 changes: 1 addition & 1 deletion imitation/config/task/lift_graph.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ dataset_path: &dataset_path ./data/lift/${task.dataset_type}/low_dim_v141.hdf5

max_steps: ${max_steps}

control_mode: "JOINT_POSITION"
control_mode: "OSC_POSE"

obs_dim: 9
action_dim: 9
Expand Down
2 changes: 1 addition & 1 deletion imitation/config/task/square_graph.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,4 @@ dataset:
object_state_sizes: *object_state_sizes
object_state_keys: *object_state_keys
control_mode: ${task.control_mode}
base_link_shift: [[-0.56, 0, 0.912]]
base_link_shift: [[-0.56, 0, 0.912]]
8 changes: 8 additions & 0 deletions imitation/config/task/transport_graph.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ base_link_rotation:
- [0.707107, 0, 0, 0.707107]
- [0.707107, 0, 0, -0.707107]

base_link_shift:
- [0.0, -0.81, 0.912]
- [0.0, 0.81, 0.912]
base_link_rotation:
- [0.707107, 0, 0, 0.707107]
- [0.707107, 0, 0, -0.707107]

env_runner:
_target_: imitation.env_runner.robomimic_lowdim_runner.RobomimicEnvRunner
output_dir: ${output_dir}
Expand Down Expand Up @@ -79,6 +86,7 @@ dataset:
obs_horizon: ${obs_horizon}
object_state_sizes: *object_state_sizes
object_state_keys: *object_state_keys
robots: *robots
control_mode: ${task.control_mode}
base_link_shift: ${task.base_link_shift}
base_link_rotation: ${task.base_link_rotation}
183 changes: 103 additions & 80 deletions imitation/dataset/robomimic_graph_dataset.py

Large diffs are not rendered by default.

36 changes: 22 additions & 14 deletions imitation/env/robomimic_graph_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self,
'''
self.object_state_sizes = object_state_sizes
self.object_state_keys = object_state_keys
self.node_feature_dim = node_feature_dim
self.node_feature_dim = 8 if control_mode == "OSC_POSE" else node_feature_dim
self.control_mode = control_mode
controller_config = load_controller_config(default_controller=self.control_mode)
# override default controller config with user-specified values
Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(self,
self.observation_space = self.env.observation_space
self.num_objects = len(object_state_keys)

self.NUM_GRAPH_NODES = self.num_robots*9 + self.num_objects # TODO add multi-robot support
self.NUM_GRAPH_NODES = 1 + self.num_objects if self.control_mode == "OSC_POSE" else self.num_robots*9 + self.num_objects
self.BASE_LINK_SHIFT = base_link_shift
self.BASE_LINK_ROTATION = base_link_rotation
self.ROBOT_NODE_TYPE = 1
Expand Down Expand Up @@ -160,16 +160,19 @@ def _get_object_pos(self, data):
def _get_node_pos(self, data):
node_pos = []
for i in range(self.num_robots):
node_pos_robot = calculate_panda_joints_positions([*data[f"robot{i}_joint_pos"], *data[f"robot{i}_gripper_qpos"]])
rotation_matrix = R.from_quat(self.BASE_LINK_ROTATION[i])
node_pos_robot[:,:3] = torch.matmul(node_pos_robot[:,:3], torch.tensor(rotation_matrix.as_matrix()))
node_pos_robot[:,3:] = torch.tensor((R.from_quat(node_pos_robot[:,3:].detach().numpy()) * rotation_matrix).as_quat())
# add base link shift
node_pos_robot[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i])
if self.control_mode == "OSC_POSE":
node_pos_robot = torch.cat([torch.tensor(data[f"robot{i}_eef_pos"]), torch.tensor(data[f"robot{i}_eef_6d_rot"])], dim=0).unsqueeze(0)
else:
node_pos_robot = calculate_panda_joints_positions([*data[f"robot{i}_joint_pos"], *data[f"robot{i}_gripper_qpos"]])
rotation_matrix = R.from_quat(self.BASE_LINK_ROTATION[i])
node_pos_robot[:,:3] = torch.matmul(node_pos_robot[:,:3], torch.tensor(rotation_matrix.as_matrix()))
node_pos_robot[:,3:] = torch.tensor((R.from_quat(node_pos_robot[:,3:].detach().numpy()) * rotation_matrix).as_quat())
# add base link shift
node_pos_robot[:,:3] += torch.tensor(self.BASE_LINK_SHIFT[i])
# use rotation transformer to convert quaternion to 6d rotation
node_pos_robot = torch.cat([node_pos_robot[:,:3], self.rotation_transformer.forward(node_pos_robot[:,3:])], dim=1)
node_pos.append(node_pos_robot)
node_pos = torch.cat(node_pos, dim=0)
# use rotation transformer to convert quaternion to 6d rotation
node_pos = torch.cat([node_pos[:,:3], self.rotation_transformer.forward(node_pos[:,3:])], dim=1)
obj_pos_tensor = self._get_object_pos(data)
node_pos = torch.cat((node_pos, obj_pos_tensor), dim=0)
return node_pos
Expand All @@ -181,9 +184,10 @@ def _get_node_feats(self, data):
Returns node features from data
'''
node_feats = []
for i in range(self.num_robots):
for i in range(len(self.robots)):
if self.control_mode == "OSC_POSE":
node_feats.append(torch.cat([torch.tensor(data[f"robot{i}_eef_pos"]), torch.tensor(data[f"robot{i}_eef_quat"])], dim=0).reshape(1, -1)) # add dimension
# node_feats.append(torch.cat([torch.tensor(data[f"robot{i}_eef_pos"]), torch.tensor(data[f"robot{i}_eef_6d_rot"])], dim=0).unsqueeze(0))
node_feats.append(torch.zeros((1,7)))
elif self.control_mode == "JOINT_VELOCITY":
node_feats.append(torch.tensor([*data[f"robot{i}_joint_vel"], *data[f"robot{i}_gripper_qvel"]]).reshape(1,-1).T)
elif self.control_mode == "JOINT_POSITION":
Expand Down Expand Up @@ -260,7 +264,7 @@ def _robosuite_obs_to_robomimic_graph(self, obs):
f"robot{i}_joint_pos": robot_joint_pos,
f"robot{i}_joint_vel": robot_joint_vel,
f"robot{i}_eef_pos": eef_pose,
f"robot{i}_eef_quat": eef_6d,
f"robot{i}_eef_6d_rot": eef_6d,
f"robot{i}_gripper_qpos": gripper_pose,
f"robot{i}_gripper_qvel": gripper_vel
})
Expand Down Expand Up @@ -297,7 +301,11 @@ def reset(self):

def step(self, action):
final_action = []
for i in range(self.num_robots):
if self.control_mode == "OSC_POSE":
obs, reward, done, _, info = self.env.step(action)
return self._robosuite_obs_to_robomimic_graph(obs), reward, done, info

for i in range(len(self.robots)):
'''
Robosuite's action space is composed of 7 joint velocities and 1 gripper velocity, while
in the robomimic datasets, it's composed of 7 joint velocities and 2 gripper velocities (for each "finger").
Expand Down
10 changes: 7 additions & 3 deletions imitation/model/graph_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,15 +356,18 @@ def positionalencoding(self, lengths):
return pes


def forward(self, x, edge_index, edge_attr, x_coord, cond, timesteps, batch=None):
def forward(self, x, edge_index, edge_attr, x_coord, cond, timesteps, batch=None, batch_cond=None, edge_index_cond=None, edge_attr_cond=None, x_coord_cond=None):
# make sure x and edge_attr are of type float, for the MLPs
x = x.float().to(self.device).flatten(start_dim=1)
edge_attr = edge_attr.float().to(self.device).unsqueeze(-1) # add channel dimension
edge_attr_cond = edge_attr_cond.float().to(self.device).unsqueeze(-1) # add channel dimension
edge_index = edge_index.to(self.device)
cond = cond.float().to(self.device)
x_coord = x_coord.float().to(self.device)
x_coord_cond = x_coord_cond.float().to(self.device)
timesteps = timesteps.to(self.device)
batch = batch.long().to(self.device)
batch_cond = batch_cond.long().to(self.device)
batch_size = batch[-1] + 1

timesteps_embed = self.diffusion_step_encoder(self.pe[timesteps])
Expand All @@ -373,15 +376,16 @@ def forward(self, x, edge_index, edge_attr, x_coord, cond, timesteps, batch=None
assert x.shape[0] == x_coord.shape[0], "x and x_coord must have the same length"

edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=x.shape[0], fill_value=self.FILL_VALUE)

edge_index_cond, edge_attr_cond = add_self_loops(edge_index_cond, edge_attr_cond, num_nodes=x.shape[0], fill_value=self.FILL_VALUE)

h_v = self.node_embedding(x)

h_v = torch.cat([h_v, timesteps_embed], dim=-1)

h_e = self.edge_embedding(edge_attr.reshape(-1, 1))

# FiLM generator
embed = self.cond_encoder(cond, edge_index, x_coord, edge_attr, batch=batch)
embed = self.cond_encoder(cond, edge_index_cond, x_coord_cond, edge_attr_cond, batch=batch_cond)
embed = embed.reshape(self.num_layers, batch_size, 2, (self.hidden_dim + self.diffusion_step_embed_dim))
scales = embed[:,:,0,...]
biases = embed[:,:,1,...]
Expand Down
10 changes: 5 additions & 5 deletions imitation/policy/egnn_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_action(self, obs_deque):
x=y[:,-1,:3].to(self.device).float(),
)
pred = pred.reshape(-1, self.pred_horizon, self.node_feature_dim)
return pred[:9,:,0].T.detach().cpu().numpy() # return joint values only
return pred[0,:,:7].detach().cpu().numpy() # return joint values only

def validate(self, dataset, model_path):
'''
Expand Down Expand Up @@ -118,7 +118,7 @@ def train(self, dataset, num_epochs, model_path, seed=0):
)

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=1e-6)
# LR scheduler with warmup
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.8)
# visualize data in batch
Expand All @@ -140,7 +140,7 @@ def train(self, dataset, num_epochs, model_path, seed=0):
x=nbatch.y[:,-1,:3].to(self.device).float(),
)
pred = pred.reshape(-1, self.pred_horizon, self.node_feature_dim)
loss = loss_fn(pred, action)
loss = loss_fn(pred[:,:,0], action[:,:,0])
# loss_x = loss_fn(x, nbatch.pos[:,:3].to(self.device).float())
loss.backward()
optimizer.step()
Expand All @@ -149,11 +149,11 @@ def train(self, dataset, num_epochs, model_path, seed=0):

pbar.set_postfix({"loss": loss.item()})
wandb.log({"loss": loss.item()})

# save model
torch.save(self.model.state_dict(), model_path)
self.global_epoch += 1
wandb.log({"epoch": self.global_epoch, "loss": loss.item()})

# save model
torch.save(self.model.state_dict(), model_path)
pbar.set_description(f"Epoch: {self.global_epoch}, Loss: {loss.item()}")

Expand Down
53 changes: 38 additions & 15 deletions imitation/policy/graph_ddpm_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
import wandb
from torch_geometric.data import Batch

from imitation.policy.base_policy import BasePolicy

Expand Down Expand Up @@ -100,6 +101,7 @@ def MOCK_get_graph_from_obs(self): # for testing purposes, remove before merge
self.playback_count += 7
log.info(f"Playing back observation {self.playback_count}")
return obs_cond, playback_graph

def get_action(self, obs_deque):
B = 1 # action shape is (B, Ta, Da), observations (B, To, Do)
# transform deques to numpy arrays
Expand Down Expand Up @@ -133,9 +135,13 @@ def get_action(self, obs_deque):
edge_index = G_t.edge_index,
edge_attr = G_t.edge_attr,
x_coord = nobs[:,-1,:3],
cond = nobs[:,:,3:],
cond = nobs[-2:,:,3:], # only end-effector and object
timesteps = torch.tensor([k], dtype=torch.long, device=self.device),
batch = torch.zeros(naction.shape[0], dtype=torch.long, device=self.device)
batch = torch.zeros(naction.shape[0], dtype=torch.long, device=self.device),
batch_cond = torch.zeros(nobs[-2:,:,3:].shape[0], dtype=torch.long, device=self.device),
edge_index_cond = torch.tensor([[0,1],[1,0]], dtype=torch.long, device=self.device),
edge_attr_cond = torch.ones(2, 1, device=self.device),
x_coord_cond = nobs[-2:,-1,3:]
)

# inverse diffusion step (remove noise)
Expand Down Expand Up @@ -182,7 +188,14 @@ def validate(self, dataset=None, model_path="last.pt"):

# observation as FiLM conditioning
# (B, node, obs_horizon, obs_dim)
obs_cond = nobs[:,:,3:]
obs_cond = nobs[:,:,3:] # only 6D rotation
# filter only 2 last nodes of each graph by batch.ptr
obs_cond = torch.cat([obs_cond[batch.ptr[i+1]-2:batch.ptr[i+1]] for i in range(B)], dim=0)
batch_cond = torch.cat([batch.batch[batch.ptr[i+1]-2:batch.ptr[i+1]] for i in range(B)], dim=0)
x_coord_cond = torch.cat([batch.y[batch.ptr[i+1]-2:batch.ptr[i+1],-1,:3] for i in range(B)], dim=0)
edge_index_cond = torch.tensor([[[2*i,2*i+1],[2*i+1,2*i]] for i in range(batch_cond.shape[0] // 2)], dtype=torch.long, device=self.device)
edge_index_cond = edge_index_cond.flatten(end_dim=1).T
edge_attr_cond = torch.ones(edge_index_cond.shape[1], 1, device=self.device)
# (B, obs_horizon * obs_dim)
obs_cond = obs_cond.flatten(start_dim=1)

Expand Down Expand Up @@ -215,14 +228,18 @@ def validate(self, dataset=None, model_path="last.pt"):
noise = noise.flatten(end_dim=1)

# predict the noise residual
noise_pred, x = self.ema_noise_pred_net(
noise_pred, x = self.noise_pred_net(
noisy_actions,
batch.edge_index,
batch.edge_attr,
x_coord = batch.y[:,-1,:3],
cond=obs_cond,
timesteps=timesteps,
batch=batch.batch)
batch=batch.batch,
batch_cond=batch_cond,
edge_index_cond=edge_index_cond,
edge_attr_cond=edge_attr_cond,
x_coord_cond=x_coord_cond)

# L2 loss
loss = nn.functional.mse_loss(noise_pred, noise)
Expand Down Expand Up @@ -267,7 +284,7 @@ def train(self,
lr_scheduler = get_scheduler(
name='cosine',
optimizer=optimizer,
num_warmup_steps=500,
num_warmup_steps=5,
num_training_steps=len(dataloader) * num_epochs
)

Expand All @@ -283,12 +300,16 @@ def train(self,
nobs = self.dataset.normalize_data(batch.y, stats_key='y', batch_size=batch.num_graphs).to(self.device)
# normalize action
naction = self.dataset.normalize_data(batch.x, stats_key='x', batch_size=batch.num_graphs).to(self.device)
naction = naction[:,:,:1]
else:
nobs_batch = Batch().from_data_list(batch["observation"])
naction_batch = Batch().from_data_list(batch["action"])
naction = naction_batch.x[:,:,:1]
B = batch.num_graphs

# observation as FiLM conditioning
# (B, node, obs_horizon, obs_dim)
obs_cond = nobs[:,:,3:] # only 6D rotation
obs_cond = nobs_batch.x[:,:,3:] # only 6D rotation

# (B, obs_horizon * obs_dim)
obs_cond = obs_cond.flatten(start_dim=1)

Expand All @@ -302,13 +323,11 @@ def train(self,
# (this is the forward diffusion process)
# split naction into (B, N_nodes, pred_horizon, node_feature_dim), selecting the items from each batch.batch

naction = torch.cat([naction[batch.batch == i].unsqueeze(0) for i in batch.batch.unique()], dim=0)
naction = torch.cat([naction[naction_batch.batch == i].unsqueeze(0) for i in naction_batch.batch.unique()], dim=0)

# add noise to first action instead of sampling from Gaussian
noise = (1 - self.noise_addition_std) * naction[:,:,0,:].unsqueeze(2).repeat(1,1,naction.shape[2],1).float() + self.noise_addition_std * torch.randn(naction.shape, device=self.device, dtype=torch.float32)

noise = torch.randn(naction.shape, device=self.device, dtype=torch.float32)

noisy_actions = self.noise_scheduler.add_noise(
naction, noise, timesteps)

Expand All @@ -324,12 +343,16 @@ def train(self,
# predict the noise residual
noise_pred, x = self.noise_pred_net(
noisy_actions,
batch.edge_index,
batch.edge_attr,
x_coord = batch.y[:,-1,:3],
naction_batch.edge_index,
naction_batch.edge_attr,
x_coord = naction_batch.pos[:,:3],
cond=obs_cond,
timesteps=timesteps,
batch=batch.batch)
batch=naction_batch.batch,
batch_cond=nobs_batch.batch,
edge_index_cond=nobs_batch.edge_index,
edge_attr_cond=nobs_batch.edge_attr,
x_coord_cond=nobs_batch.pos[:,:3])

# L2 loss
loss = nn.functional.mse_loss(noise_pred, noise)
Expand Down
2 changes: 1 addition & 1 deletion imitation/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from scipy.spatial.transform import Rotation as R


from torch_robotics.torch_kinematics_tree.models.robots import DifferentiableFrankaPanda
from torch_kinematics_tree.models.robots import DifferentiableFrankaPanda

def to_numpy(x):
return x.detach().cpu().numpy()
Expand Down