Skip to content

Commit

Permalink
Merge branch 'xpbd-torque-application' into 'main'
Browse files Browse the repository at this point in the history
Fix joint_act indexing in XPBD apply_joint_actions

See merge request omniverse/warp!549
  • Loading branch information
mmacklin committed Jun 6, 2024
2 parents 5cb4670 + ea1e81d commit 0d62e27
Showing 1 changed file with 16 additions and 22 deletions.
38 changes: 16 additions & 22 deletions warp/sim/integrator_xpbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,11 +972,9 @@ def apply_body_delta_velocities(


@wp.kernel
def apply_joint_torques(
def apply_joint_actions(
body_q: wp.array(dtype=wp.transform),
body_com: wp.array(dtype=wp.vec3),
joint_q_start: wp.array(dtype=int),
joint_qd_start: wp.array(dtype=int),
joint_type: wp.array(dtype=int),
joint_parent: wp.array(dtype=int),
joint_child: wp.array(dtype=int),
Expand Down Expand Up @@ -1028,8 +1026,6 @@ def apply_joint_torques(
# q_c = wp.transform_get_rotation(X_wc)

# joint properties (for 1D joints)
# q_start = joint_q_start[tid]
qd_start = joint_qd_start[tid]
axis_start = joint_axis_start[tid]
lin_axis_count = joint_axis_dim[tid, 0]
ang_axis_count = joint_axis_dim[tid, 1]
Expand All @@ -1043,14 +1039,14 @@ def apply_joint_torques(
mode = joint_axis_mode[axis_start]
if mode == wp.sim.JOINT_MODE_FORCE:
axis = joint_axis[axis_start]
act = joint_act[qd_start]
act = joint_act[axis_start]
a_p = wp.transform_vector(X_wp, axis)
t_total += act * a_p
elif type == wp.sim.JOINT_PRISMATIC:
mode = joint_axis_mode[axis_start]
if mode == wp.sim.JOINT_MODE_FORCE:
axis = joint_axis[axis_start]
act = joint_act[qd_start]
act = joint_act[axis_start]
a_p = wp.transform_vector(X_wp, axis)
f_total += act * a_p
elif type == wp.sim.JOINT_COMPOUND:
Expand All @@ -1075,13 +1071,13 @@ def apply_joint_torques(

if joint_axis_mode[axis_start + 0] == wp.sim.JOINT_MODE_FORCE:
axis_0 = joint_axis[axis_start + 0]
t_total += joint_act[qd_start + 0] * wp.transform_vector(X_wp, axis_0)
t_total += joint_act[axis_start + 0] * wp.transform_vector(X_wp, axis_0)
if joint_axis_mode[axis_start + 1] == wp.sim.JOINT_MODE_FORCE:
axis_1 = joint_axis[axis_start + 1]
t_total += joint_act[qd_start + 1] * wp.transform_vector(X_wp, axis_1)
t_total += joint_act[axis_start + 1] * wp.transform_vector(X_wp, axis_1)
if joint_axis_mode[axis_start + 2] == wp.sim.JOINT_MODE_FORCE:
axis_2 = joint_axis[axis_start + 2]
t_total += joint_act[qd_start + 2] * wp.transform_vector(X_wp, axis_2)
t_total += joint_act[axis_start + 2] * wp.transform_vector(X_wp, axis_2)

elif type == wp.sim.JOINT_UNIVERSAL:
# q_off = wp.transform_get_rotation(X_cj)
Expand All @@ -1107,10 +1103,10 @@ def apply_joint_torques(

if joint_axis_mode[axis_start + 0] == wp.sim.JOINT_MODE_FORCE:
axis_0 = joint_axis[axis_start + 0]
t_total += joint_act[qd_start + 0] * wp.transform_vector(X_wp, axis_0)
t_total += joint_act[axis_start + 0] * wp.transform_vector(X_wp, axis_0)
if joint_axis_mode[axis_start + 1] == wp.sim.JOINT_MODE_FORCE:
axis_1 = joint_axis[axis_start + 1]
t_total += joint_act[qd_start + 1] * wp.transform_vector(X_wp, axis_1)
t_total += joint_act[axis_start + 1] * wp.transform_vector(X_wp, axis_1)

elif type == wp.sim.JOINT_D6:
# unroll for loop to ensure joint actions remain differentiable
Expand All @@ -1119,43 +1115,43 @@ def apply_joint_torques(
if lin_axis_count > 0:
if joint_axis_mode[axis_start + 0] == wp.sim.JOINT_MODE_FORCE:
axis = joint_axis[axis_start + 0]
act = joint_act[qd_start + 0]
act = joint_act[axis_start + 0]
a_p = wp.transform_vector(X_wp, axis)
f_total += act * a_p
if lin_axis_count > 1:
if joint_axis_mode[axis_start + 1] == wp.sim.JOINT_MODE_FORCE:
axis = joint_axis[axis_start + 1]
act = joint_act[qd_start + 1]
act = joint_act[axis_start + 1]
a_p = wp.transform_vector(X_wp, axis)
f_total += act * a_p
if lin_axis_count > 2:
if joint_axis_mode[axis_start + 2] == wp.sim.JOINT_MODE_FORCE:
axis = joint_axis[axis_start + 2]
act = joint_act[qd_start + 2]
act = joint_act[axis_start + 2]
a_p = wp.transform_vector(X_wp, axis)
f_total += act * a_p

if ang_axis_count > 0:
if joint_axis_mode[axis_start + lin_axis_count + 0] == wp.sim.JOINT_MODE_FORCE:
axis = joint_axis[axis_start + lin_axis_count + 0]
act = joint_act[qd_start + lin_axis_count + 0]
act = joint_act[axis_start + lin_axis_count + 0]
a_p = wp.transform_vector(X_wp, axis)
t_total += act * a_p
if ang_axis_count > 1:
if joint_axis_mode[axis_start + lin_axis_count + 1] == wp.sim.JOINT_MODE_FORCE:
axis = joint_axis[axis_start + lin_axis_count + 1]
act = joint_act[qd_start + lin_axis_count + 1]
act = joint_act[axis_start + lin_axis_count + 1]
a_p = wp.transform_vector(X_wp, axis)
t_total += act * a_p
if ang_axis_count > 2:
if joint_axis_mode[axis_start + lin_axis_count + 2] == wp.sim.JOINT_MODE_FORCE:
axis = joint_axis[axis_start + lin_axis_count + 2]
act = joint_act[qd_start + lin_axis_count + 2]
act = joint_act[axis_start + lin_axis_count + 2]
a_p = wp.transform_vector(X_wp, axis)
t_total += act * a_p

else:
print("joint type not handled in apply_joint_torques")
print("joint type not handled in apply_joint_actions")

# write forces
if id_p >= 0:
Expand Down Expand Up @@ -2838,13 +2834,11 @@ def simulate(self, model: Model, state_in: State, state_out: State, dt: float, c

if model.joint_count:
wp.launch(
kernel=apply_joint_torques,
kernel=apply_joint_actions,
dim=model.joint_count,
inputs=[
state_in.body_q,
model.body_com,
model.joint_q_start,
model.joint_qd_start,
model.joint_type,
model.joint_parent,
model.joint_child,
Expand Down

0 comments on commit 0d62e27

Please sign in to comment.