diff --git a/genesis/assets/xml/cartpole.xml b/genesis/assets/xml/cartpole.xml
new file mode 100644
index 0000000000..d7aba84f55
--- /dev/null
+++ b/genesis/assets/xml/cartpole.xml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/genesis/assets/xml/grad/capsule.xml b/genesis/assets/xml/grad/capsule.xml
new file mode 100644
index 0000000000..21fb2cc712
--- /dev/null
+++ b/genesis/assets/xml/grad/capsule.xml
@@ -0,0 +1,9 @@
+
+
+
+
+
+
+
+
+
diff --git a/genesis/assets/xml/grad/free.xml b/genesis/assets/xml/grad/free.xml
new file mode 100644
index 0000000000..b2be573454
--- /dev/null
+++ b/genesis/assets/xml/grad/free.xml
@@ -0,0 +1,9 @@
+
+
+
+
+
+
+
+
+
diff --git a/genesis/assets/xml/grad/free_with_revolute.xml b/genesis/assets/xml/grad/free_with_revolute.xml
new file mode 100644
index 0000000000..e867e68abc
--- /dev/null
+++ b/genesis/assets/xml/grad/free_with_revolute.xml
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/genesis/assets/xml/grad/prismatic.xml b/genesis/assets/xml/grad/prismatic.xml
new file mode 100644
index 0000000000..d59d4e76ff
--- /dev/null
+++ b/genesis/assets/xml/grad/prismatic.xml
@@ -0,0 +1,9 @@
+
+
+
+
+
+
+
+
+
diff --git a/genesis/assets/xml/grad/revolute.xml b/genesis/assets/xml/grad/revolute.xml
new file mode 100644
index 0000000000..8221d398a3
--- /dev/null
+++ b/genesis/assets/xml/grad/revolute.xml
@@ -0,0 +1,9 @@
+
+
+
+
+
+
+
+
+
diff --git a/genesis/assets/xml/grad/revolute_chain3.xml b/genesis/assets/xml/grad/revolute_chain3.xml
new file mode 100644
index 0000000000..f4ca0bcfb5
--- /dev/null
+++ b/genesis/assets/xml/grad/revolute_chain3.xml
@@ -0,0 +1,19 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/genesis/assets/xml/grad/slider_limit.xml b/genesis/assets/xml/grad/slider_limit.xml
new file mode 100644
index 0000000000..1b1a33c142
--- /dev/null
+++ b/genesis/assets/xml/grad/slider_limit.xml
@@ -0,0 +1,10 @@
+
+
+
+
+
+
+
+
+
+
diff --git a/genesis/assets/xml/grad/spherical.xml b/genesis/assets/xml/grad/spherical.xml
new file mode 100644
index 0000000000..2496c44a80
--- /dev/null
+++ b/genesis/assets/xml/grad/spherical.xml
@@ -0,0 +1,9 @@
+
+
+
+
+
+
+
+
+
diff --git a/genesis/assets/xml/hopper.xml b/genesis/assets/xml/hopper.xml
new file mode 100644
index 0000000000..63859b2b3f
--- /dev/null
+++ b/genesis/assets/xml/hopper.xml
@@ -0,0 +1,27 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py
index 0db2c18640..523e28d287 100644
--- a/genesis/engine/entities/rigid_entity/rigid_entity.py
+++ b/genesis/engine/entities/rigid_entity/rigid_entity.py
@@ -109,7 +109,7 @@ def __init__(
self._load_model()
# Initialize target variables and checkpoint
- self._tgt_keys = ("pos", "quat", "qpos", "dofs_velocity")
+ self._tgt_keys = ("pos", "quat", "qpos", "dofs_velocity", "control_dofs_force")
self._tgt = dict()
self._tgt_buffer = list()
self._ckpt = dict()
@@ -1156,6 +1156,8 @@ def process_input(self, in_backward=False):
self.set_quat(**data_kwargs)
case "set_dofs_velocity":
self.set_dofs_velocity(**data_kwargs)
+ case "control_dofs_force":
+ self.control_dofs_force(**data_kwargs)
case _:
gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}")
@@ -1188,6 +1190,15 @@ def process_input_grad(self):
data_kwargs["dofs_idx_local"],
data_kwargs["envs_idx"],
)
+
+ case "control_dofs_force":
+ force = data_kwargs.pop("force")
+ if force is not None and force.requires_grad:
+ force._backward_from_qd(
+ self.set_dofs_force_grad,
+ data_kwargs["dofs_idx_local"],
+ data_kwargs["envs_idx"],
+ )
case _:
gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}")
@@ -3559,6 +3570,11 @@ def set_dofs_velocity_grad(self, dofs_idx_local, envs_idx, velocity_grad):
dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
self._solver.set_dofs_velocity_grad(dofs_idx, envs_idx, velocity_grad.data)
+ @gs.assert_built
+ def set_dofs_force_grad(self, dofs_idx_local, envs_idx, force_grad):
+ """
+ Hand the control-force gradient request for this entity's dofs over to the solver.
+
+ `process_input_grad` passes this method as the callback of `force._backward_from_qd` when it
+ handles a "control_dofs_force" target.
+
+ `dofs_idx_local` holds entity-local dof indices; they are converted to solver-global indices
+ with `_get_global_idx` (using `self.n_dofs` and `self._dof_start`) before the call.
+ `envs_idx` is forwarded to the solver unchanged.
+ `force_grad` is an object exposing `.data`; that underlying data is what the solver receives.
+ NOTE(review): presumably the solver fills it in place with the gradient -- confirm against
+ `self._solver.set_dofs_force_grad`.
+ """
+ dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
+ self._solver.set_dofs_force_grad(dofs_idx, envs_idx, force_grad.data)
+
# ------------------------------------------------------------------------------------
# ----------------------------- DOF property setters ---------------------------------
# ------------------------------------------------------------------------------------
@@ -3592,6 +3608,7 @@ def set_dofs_position(self, position, dofs_idx_local=None, envs_idx=None, *, zer
# ------------------------------------------------------------------------------------
@gs.assert_built
+ @tracked
def control_dofs_force(self, force, dofs_idx_local=None, envs_idx=None):
"""
Control the entity's dofs' motor force. This is used for force/torque control.
diff --git a/genesis/engine/scene.py b/genesis/engine/scene.py
index 7ff57c47f4..65f55f0f66 100644
--- a/genesis/engine/scene.py
+++ b/genesis/engine/scene.py
@@ -977,13 +977,16 @@ def reset(self, state: SimState | None = None, envs_idx=None):
self._reset(state, envs_idx=envs_idx)
self._recorder_manager.reset(envs_idx)
- def _reset(self, state: SimState | None = None, *, envs_idx=None):
+ def _reset(self, state: SimState | None = None, *, envs_idx=None, keep_init: bool = False):
if self._is_built:
if state is None:
state = self._init_state
else:
assert isinstance(state, SimState), "state must be a SimState object"
- self._init_state = state
+ # `keep_init=True` restores the state without making it the new
+ # init, so a later bare `reset()` still rewinds to the true init.
+ if not keep_init:
+ self._init_state = state
self._sim.reset(state, envs_idx)
else:
self._init_state = self._get_state()
@@ -1004,6 +1007,49 @@ def _reset(self, state: SimState | None = None, *, envs_idx=None):
def _reset_grad(self):
self._backward_ready = True
+    @gs.assert_built
+    def backward(self, loss: torch.Tensor, *args, **kwargs):
+        """Differentiate `loss` and restore the terminal physics state.
+
+        Wraps the snapshot/backward/restore dance that differentiable rollouts
+        otherwise have to perform by hand. `scene._backward()` rewinds physics
+        state to step 0 as a side-effect of unrolling the adstack, so the safe
+        pattern is to snapshot the terminal state *before* backward and restore
+        it *after*:
+
+            snapshot = scene.get_state()  # terminal state
+            loss.backward()               # rewinds physics to step 0
+            scene.reset(snapshot)         # restore + clear grads + re-arm
+
+        This method follows the same pattern, so callers can just write
+        `scene.backward(loss)`. Afterwards the scene sits at the terminal physics
+        state with grads cleared and forward/backward re-armed — ready to continue
+        the rollout or to be reset to a fresh init.
+
+        Unlike `scene.reset(snapshot)`, the restore is done with `keep_init=True`,
+        so the registered initial state (`reset()` with no args) is left untouched.
+
+        Parameters
+        ----------
+        loss : torch.Tensor
+            Scalar loss to differentiate.
+        *args, **kwargs
+            Forwarded to `torch.autograd.backward` after `loss` (e.g.
+            `grad_tensors`, `retain_graph`). Note that the keyword is
+            `grad_tensors`, not `gradient` as in `torch.Tensor.backward`.
+            `retain_graph` defaults to True when the caller does not supply it.
+
+        Returns
+        -------
+        SimState
+            The terminal-state snapshot taken before differentiation, i.e. the
+            state the scene has been restored to.
+        """
+        # Snapshot the terminal state before backward rewinds physics to step 0.
+        snapshot = self.get_state()
+        # `scene._backward()` re-enters the torch graph from each step's queried
+        # states (`_backward_from_qd` -> `state.backward(retain_graph=True)`), so
+        # the graph must survive the initial autograd pass.
+        # `retain_graph` is the second positional parameter after `loss`
+        # (`grad_tensors` comes first): only inject the default when it was not
+        # passed positionally, otherwise the call below would fail with
+        # "got multiple values for argument 'retain_graph'".
+        if len(args) < 2:
+            kwargs.setdefault("retain_graph", True)
+        # Functional `torch.autograd.backward` fills torch + queried-state grads
+        # WITHOUT triggering `gs.Tensor.backward`'s auto `scene._backward()`, so
+        # we drive the sim unroll explicitly below.
+        torch.autograd.backward(loss, *args, **kwargs)
+        self._backward()
+        # Restore to the terminal snapshot; `keep_init=True` preserves the real
+        # initial state so a later bare `reset()` still rewinds to it.
+        self._reset(snapshot, keep_init=True)
+        return snapshot
+
def _get_state(self):
return self._sim.get_state()
diff --git a/genesis/engine/solvers/kinematic_solver.py b/genesis/engine/solvers/kinematic_solver.py
index 92ae00be0b..1acf7fb84f 100644
--- a/genesis/engine/solvers/kinematic_solver.py
+++ b/genesis/engine/solvers/kinematic_solver.py
@@ -45,6 +45,7 @@
kernel_set_dofs_position,
kernel_set_dofs_velocity,
kernel_set_dofs_velocity_grad,
+ kernel_set_dofs_force_grad,
kernel_set_dofs_zero_velocity,
kernel_set_links_pos,
kernel_set_links_quat,
@@ -951,6 +952,14 @@ def set_dofs_velocity_grad(self, dofs_idx, envs_idx, velocity_grad):
velocity_grad_, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config
)
+ def set_dofs_force_grad(self, dofs_idx, envs_idx, force_grad):
+ """
+ Run `kernel_set_dofs_force_grad` for the given dofs / environments.
+
+ `force_grad`, `dofs_idx` and `envs_idx` are first normalized by `_sanitize_io_variables`
+ (with `skip_allocation=True`); the sanitized values are what the kernel receives, together
+ with `self.dofs_state` and the static rigid sim config.
+ NOTE(review): presumably `force_grad` is written in place and acts as the output buffer --
+ confirm against the kernel and `_sanitize_io_variables`.
+ """
+ force_grad_, dofs_idx, envs_idx = self._sanitize_io_variables(
+ force_grad, dofs_idx, self.n_dofs, "dofs_idx", envs_idx, skip_allocation=True
+ )
+ # With `n_envs == 0` a leading axis of size 1 is prepended before launching the kernel.
+ # NOTE(review): presumably because the kernel expects a leading environment axis -- confirm.
+ if self.n_envs == 0:
+ force_grad_ = force_grad_.unsqueeze(0)
+ kernel_set_dofs_force_grad(force_grad_, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config)
+
def set_dofs_position(self, position, dofs_idx=None, envs_idx=None):
position, dofs_idx, envs_idx = self._sanitize_io_variables(
position, dofs_idx, self.n_dofs, "dofs_idx", envs_idx, skip_allocation=True
diff --git a/genesis/engine/solvers/rigid/abd/accessor.py b/genesis/engine/solvers/rigid/abd/accessor.py
index 672e47ec7c..aab9ae7381 100644
--- a/genesis/engine/solvers/rigid/abd/accessor.py
+++ b/genesis/engine/solvers/rigid/abd/accessor.py
@@ -787,6 +787,20 @@ def kernel_set_dofs_zero_velocity(
dofs_state.vel[dofs_idx[i_d_], envs_idx[i_b_]] = 0.0
+@qd.kernel(fastcache=True)
+def kernel_set_dofs_force_grad(
+ force_grad: qd.types.ndarray(),
+ dofs_idx: qd.types.ndarray(),
+ envs_idx: qd.types.ndarray(),
+ dofs_state: array_class.DofsState,
+ static_rigid_sim_config: qd.template(),
+):
+ # Export `dofs_state.ctrl_force.grad` for every selected (dof, env) pair into `force_grad`,
+ # then clear the exported entries.
+ # `dofs_idx` / `envs_idx` hold the indices used to address `dofs_state`, while the loop
+ # counters themselves address the output, which is indexed as [env, dof].
+ # The loop is serialized when `para_level` is below `gs.PARA_LEVEL.PARTIAL`.
+ qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)
+ for i_d_, i_b_ in qd.ndrange(dofs_idx.shape[0], envs_idx.shape[0]):
+ force_grad[i_b_, i_d_] = dofs_state.ctrl_force.grad[dofs_idx[i_d_], envs_idx[i_b_]]
+ # Zero the source entry once it has been read out.
+ # NOTE(review): presumably so it is not accumulated again by a later pass -- confirm.
+ dofs_state.ctrl_force.grad[dofs_idx[i_d_], envs_idx[i_b_]] = 0.0
+
+
@qd.kernel(fastcache=True)
def kernel_set_dofs_position(
position: qd.types.ndarray(),
diff --git a/genesis/engine/solvers/rigid/abd/diff.py b/genesis/engine/solvers/rigid/abd/diff.py
index 39b491011f..a872d65344 100644
--- a/genesis/engine/solvers/rigid/abd/diff.py
+++ b/genesis/engine/solvers/rigid/abd/diff.py
@@ -19,7 +19,7 @@
import genesis as gs
import genesis.utils.geom as gu
import genesis.utils.array_class as array_class
-from .forward_kinematics import func_update_cartesian_space
+from .forward_kinematics import func_update_cartesian_space, func_forward_velocity
@qd.func
@@ -181,6 +181,16 @@ def kernel_prepare_backward_substep(
force_update_fixed_geoms=False,
is_backward=True,
)
+ func_forward_velocity(
+ entities_info=entities_info,
+ links_info=links_info,
+ links_state=links_state,
+ joints_info=joints_info,
+ dofs_state=dofs_state,
+ rigid_global_info=rigid_global_info,
+ static_rigid_sim_config=static_rigid_sim_config,
+ is_backward=True,
+ )
# FIXME: Parameter pruning for ndarray is buggy for now and requires match variable and arg names.
# Save results of [update_cartesian_space] to adjoint cache
@@ -232,17 +242,17 @@ def kernel_begin_backward_substep(
)
if not static_rigid_sim_config.enable_mujoco_compatibility:
- # FIXME: Parameter pruning for ndarray is buggy for now and requires match variable and arg names.
- # Save results of [update_cartesian_space] to adjoint cache
+ # Restore the cartesian space that was overwritten by
+ # post-integrate forward replay in the backward substep
func_copy_cartesian_space(
- dofs_state=dofs_state,
- links_state=links_state,
- joints_state=joints_state,
- geoms_state=geoms_state,
- dofs_state_adjoint_cache=dofs_state_adjoint_cache,
- links_state_adjoint_cache=links_state_adjoint_cache,
- joints_state_adjoint_cache=joints_state_adjoint_cache,
- geoms_state_adjoint_cache=geoms_state_adjoint_cache,
+ dofs_state=dofs_state_adjoint_cache,
+ links_state=links_state_adjoint_cache,
+ joints_state=joints_state_adjoint_cache,
+ geoms_state=geoms_state_adjoint_cache,
+ dofs_state_adjoint_cache=dofs_state,
+ links_state_adjoint_cache=links_state,
+ joints_state_adjoint_cache=joints_state,
+ geoms_state_adjoint_cache=geoms_state,
static_rigid_sim_config=static_rigid_sim_config,
)
@@ -347,6 +357,27 @@ def kernel_copy_acc(
dofs_state.acc[i_d, i_b] = rigid_adjoint_cache.dofs_acc[f, i_d, i_b]
+@qd.kernel(fastcache=True)
+def kernel_copy_next_to_curr_no_check(
+ dofs_state: array_class.DofsState,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: qd.template(),
+):
+ # Unguarded copy of the `_next` slots to the current slots: `qpos_next` -> `qpos` and
+ # `vel_next` -> `vel`, for every index along both axes. Used in the backward substep right
+ # before the forward replay so the BW kernels see the post-integrate qpos / vel.
+ # Loop bounds come from the array shapes: axis 0 of `qpos` / `vel`, and axis 1 of `vel` (`_B`),
+ # which is used as the second-axis bound for both loops.
+ n_qs = rigid_global_info.qpos.shape[0]
+ n_dofs = dofs_state.vel.shape[0]
+ _B = dofs_state.vel.shape[1]
+
+ # Generalized positions. Each loop is serialized when `para_level` is below `gs.PARA_LEVEL.ALL`.
+ qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+ for i_q, i_b in qd.ndrange(n_qs, _B):
+ rigid_global_info.qpos[i_q, i_b] = rigid_global_info.qpos_next[i_q, i_b]
+
+ # Dof velocities.
+ qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
+ for i_d, i_b in qd.ndrange(n_dofs, _B):
+ dofs_state.vel[i_d, i_b] = dofs_state.vel_next[i_d, i_b]
+
+
@qd.func
def func_integrate_dq_entity(
dq,
diff --git a/genesis/engine/solvers/rigid/abd/forward_dynamics.py b/genesis/engine/solvers/rigid/abd/forward_dynamics.py
index b5d7535554..c2e8b30aa3 100644
--- a/genesis/engine/solvers/rigid/abd/forward_dynamics.py
+++ b/genesis/engine/solvers/rigid/abd/forward_dynamics.py
@@ -36,10 +36,7 @@ def update_qacc_from_qvel_delta(
dofs_state: array_class.DofsState,
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: qd.template(),
- is_backward: qd.template(),
):
- BW = qd.static(is_backward)
-
n_dofs = dofs_state.ctrl_mode.shape[0]
_B = dofs_state.ctrl_mode.shape[1]
@@ -67,10 +64,7 @@ def update_qvel(
dofs_state: array_class.DofsState,
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: qd.template(),
- is_backward: qd.template(),
):
- BW = qd.static(is_backward)
-
_B = dofs_state.vel.shape[1]
n_dofs = dofs_state.vel.shape[0]
@@ -124,7 +118,6 @@ def kernel_compute_mass_matrix(
dofs_info=dofs_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
- is_backward=False,
)
@@ -163,7 +156,6 @@ def func_forward_dynamics(
dofs_info=dofs_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
- is_backward=is_backward,
)
func_torque_and_passive_force(
entities_state=entities_state,
@@ -484,220 +476,135 @@ def func_factor_mass(
dofs_info: array_class.DofsInfo,
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: qd.template(),
- is_backward: qd.template(),
):
- BW = qd.static(is_backward)
-
- if qd.static(not BW):
- n_entities = entities_info.n_links.shape[0]
- _B = dofs_state.ctrl_mode.shape[1]
-
- if qd.static(
- not static_rigid_sim_config.enable_tiled_cholesky_mass_matrix or static_rigid_sim_config.backend == gs.cpu
- ):
- qd.loop_config(name="factor_mass", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)
- for i_e, i_b in qd.ndrange(n_entities, _B):
- if rigid_global_info.mass_mat_mask[i_e, i_b]:
- entity_dof_start = entities_info.dof_start[i_e]
- entity_dof_end = entities_info.dof_end[i_e]
- n_dofs = entities_info.n_dofs[i_e]
-
- for i_d in range(entity_dof_start, entity_dof_end):
- for j_d in range(entity_dof_start, i_d + 1):
- rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat[i_d, j_d, i_b]
-
- if qd.static(implicit_damping):
- I_d = [i_d, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d
- rigid_global_info.mass_mat_L[i_d, i_d, i_b] = (
- rigid_global_info.mass_mat_L[i_d, i_d, i_b]
- + dofs_info.damping[I_d] * rigid_global_info.substep_dt[None]
- )
- if qd.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast):
- if dofs_state.ctrl_mode[i_d, i_b] <= gs.CTRL_MODE.VELOCITY:
- rigid_global_info.mass_mat_L[i_d, i_d, i_b] = (
- rigid_global_info.mass_mat_L[i_d, i_d, i_b]
- - dofs_info.act_bias[I_d][2] * rigid_global_info.substep_dt[None]
- )
-
- for i_d_ in range(n_dofs):
- i_d = entity_dof_end - i_d_ - 1
- D_inv = 1.0 / rigid_global_info.mass_mat_L[i_d, i_d, i_b]
- rigid_global_info.mass_mat_D_inv[i_d, i_b] = D_inv
-
- for j_d_ in range(i_d - entity_dof_start):
- j_d = i_d - j_d_ - 1
- a = rigid_global_info.mass_mat_L[i_d, j_d, i_b] * D_inv
- for k_d in range(entity_dof_start, j_d + 1):
- rigid_global_info.mass_mat_L[j_d, k_d, i_b] -= (
- a * rigid_global_info.mass_mat_L[i_d, k_d, i_b]
- )
- rigid_global_info.mass_mat_L[i_d, j_d, i_b] = a
-
- # FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them.
- rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0
- else:
- BLOCK_DIM = qd.static(32)
- MAX_DOFS_PER_ENTITY = qd.static(static_rigid_sim_config.tiled_n_dofs_per_entity)
- WARP_SIZE = qd.static(32)
-
- qd.loop_config(name="factor_mass", block_dim=BLOCK_DIM)
- for i in range(n_entities * _B * BLOCK_DIM):
- tid = i % BLOCK_DIM
- i_e = (i // BLOCK_DIM) % n_entities
- i_b = i // (BLOCK_DIM * n_entities)
- if i_b >= _B:
- continue
-
- if rigid_global_info.mass_mat_mask[i_e, i_b]:
- entity_dof_start = entities_info.dof_start[i_e]
- entity_dof_end = entities_info.dof_end[i_e]
- n_dofs = entities_info.n_dofs[i_e]
- n_lower_tri = n_dofs * (n_dofs + 1) // 2
-
- mass_mat = qd.simt.block.SharedArray((MAX_DOFS_PER_ENTITY, MAX_DOFS_PER_ENTITY + 1), gs.qd_float)
-
- i_pair = tid
- while i_pair < n_lower_tri:
- i_d_ = qd.cast((qd.sqrt(8 * i_pair + 1) - 1) // 2, qd.i32)
- j_d_ = i_pair - i_d_ * (i_d_ + 1) // 2
- i_d = entity_dof_start + i_d_
- j_d = entity_dof_start + j_d_
- mass_mat[i_d_, j_d_] = rigid_global_info.mass_mat[i_d, j_d, i_b]
- i_pair = i_pair + BLOCK_DIM
- qd.simt.block.sync()
-
- if qd.static(implicit_damping):
- i_d_ = tid
- while i_d_ < n_dofs:
- i_d = entity_dof_start + i_d_
- I_d = [i_d, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d
- mass_mat[i_d_, i_d_] = (
- mass_mat[i_d_, i_d_] + dofs_info.damping[I_d] * rigid_global_info.substep_dt[None]
- )
- if qd.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast):
- if dofs_state.ctrl_mode[i_d, i_b] <= gs.CTRL_MODE.VELOCITY:
- mass_mat[i_d_, i_d_] = (
- mass_mat[i_d_, i_d_]
- - dofs_info.act_bias[I_d][2] * rigid_global_info.substep_dt[None]
- )
- i_d_ = i_d_ + BLOCK_DIM
- qd.simt.block.sync()
-
- for j in range(n_dofs):
- i_d_ = n_dofs - j - 1
- i_d = entity_dof_end - j - 1
-
- D_inv = 1.0 / mass_mat[i_d_, i_d_]
- if tid == 0:
- rigid_global_info.mass_mat_D_inv[i_d, i_b] = D_inv
- # FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them.
- rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0
-
- j_d_ = i_d_ - 1 - tid
- while j_d_ >= 0:
- a = mass_mat[i_d_, j_d_] * D_inv
- for k_d in range(j_d_ + 1):
- mass_mat[j_d_, k_d] = mass_mat[j_d_, k_d] - a * mass_mat[i_d_, k_d]
- mass_mat[i_d_, j_d_] = a
- j_d_ = j_d_ - BLOCK_DIM
- if qd.static(static_rigid_sim_config.backend == gs.cuda):
- if i_d_ <= WARP_SIZE:
- qd.simt.warp.sync(qd.u32(0xFFFFFFFF))
- else:
- qd.simt.block.sync()
- else:
- qd.simt.block.sync()
+ n_entities = entities_info.n_links.shape[0]
+ _B = dofs_state.ctrl_mode.shape[1]
- i_pair = tid
- n_strict_lower_tri = n_dofs * (n_dofs - 1) // 2
- while i_pair < n_strict_lower_tri:
- i_d_ = qd.cast((qd.sqrt(8 * i_pair + 1) + 1) // 2, qd.i32)
- j_d_ = i_pair - i_d_ * (i_d_ - 1) // 2
- i_d = entity_dof_start + i_d_
- j_d = entity_dof_start + j_d_
- rigid_global_info.mass_mat_L[i_d, j_d, i_b] = mass_mat[i_d_, j_d_]
- i_pair = i_pair + BLOCK_DIM
- else:
- # Cholesky decomposition that has safe access pattern and robust handling of divide by zero for AD. Even though
- # it is logically equivalent to the above block, it shows slightly numerical difference in the result, and thus
- # it fails for a unit test ("test_urdf_rope"), while passing all the others. TODO: Investigate if we can fix this
- # and only use this block.
-
- # Assume this is the outermost loop
- qd.loop_config(serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL))
- for i_e, i_b in qd.ndrange(entities_info.n_links.shape[0], dofs_state.ctrl_mode.shape[1]):
+ if qd.static(
+ not static_rigid_sim_config.enable_tiled_cholesky_mass_matrix or static_rigid_sim_config.backend == gs.cpu
+ ):
+ qd.loop_config(name="factor_mass", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)
+ for i_e, i_b in qd.ndrange(n_entities, _B):
if rigid_global_info.mass_mat_mask[i_e, i_b]:
- EPS = rigid_global_info.EPS[None]
-
entity_dof_start = entities_info.dof_start[i_e]
entity_dof_end = entities_info.dof_end[i_e]
n_dofs = entities_info.n_dofs[i_e]
- for i_d0 in range(n_dofs):
- i_d = entity_dof_start + i_d0
- i_pr = (entity_dof_start + entity_dof_end - 1) - i_d
+ for i_d in range(entity_dof_start, entity_dof_end):
for j_d in range(entity_dof_start, i_d + 1):
- j_pr = (entity_dof_start + entity_dof_end - 1) - j_d
- rigid_global_info.mass_mat_L_bw[0, i_pr, j_pr, i_b] = rigid_global_info.mass_mat[i_d, j_d, i_b]
- rigid_global_info.mass_mat_L_bw[0, j_pr, i_pr, i_b] = rigid_global_info.mass_mat[i_d, j_d, i_b]
+ rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat[i_d, j_d, i_b]
if qd.static(implicit_damping):
I_d = [i_d, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d
- qd.atomic_add(
- rigid_global_info.mass_mat_L_bw[0, i_pr, i_pr, i_b],
- (dofs_info.damping[I_d] * rigid_global_info.substep_dt[None]),
+ rigid_global_info.mass_mat_L[i_d, i_d, i_b] = (
+ rigid_global_info.mass_mat_L[i_d, i_d, i_b]
+ + dofs_info.damping[I_d] * rigid_global_info.substep_dt[None]
)
if qd.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast):
if dofs_state.ctrl_mode[i_d, i_b] <= gs.CTRL_MODE.VELOCITY:
- qd.atomic_add(
- rigid_global_info.mass_mat_L_bw[0, i_pr, i_pr, i_b],
- -dofs_info.act_bias[I_d][2] * rigid_global_info.substep_dt[None],
+ rigid_global_info.mass_mat_L[i_d, i_d, i_b] = (
+ rigid_global_info.mass_mat_L[i_d, i_d, i_b]
+ - dofs_info.act_bias[I_d][2] * rigid_global_info.substep_dt[None]
)
- # Cholesky-Banachiewicz algorithm (in the perturbed indices), access pattern is safe for autodiff
- # https://en.wikipedia.org/wiki/Cholesky_decomposition
- for p_i0 in range(n_dofs):
- for p_j0 in range(p_i0 + 1):
- # j_pr <= i_pr
- i_pr = entity_dof_start + p_i0
- j_pr = entity_dof_start + p_j0
-
- sum = gs.qd_float(0.0)
- for p_k0 in range(p_j0):
- # k_pr < j_pr
- k_pr = entity_dof_start + p_k0
- sum = sum + (
- rigid_global_info.mass_mat_L_bw[1, i_pr, k_pr, i_b]
- * rigid_global_info.mass_mat_L_bw[1, j_pr, k_pr, i_b]
+ for i_d_ in range(n_dofs):
+ i_d = entity_dof_end - i_d_ - 1
+ D_inv = 1.0 / rigid_global_info.mass_mat_L[i_d, i_d, i_b]
+ rigid_global_info.mass_mat_D_inv[i_d, i_b] = D_inv
+
+ for j_d_ in range(i_d - entity_dof_start):
+ j_d = i_d - j_d_ - 1
+ a = rigid_global_info.mass_mat_L[i_d, j_d, i_b] * D_inv
+ for k_d in range(entity_dof_start, j_d + 1):
+ rigid_global_info.mass_mat_L[j_d, k_d, i_b] -= (
+ a * rigid_global_info.mass_mat_L[i_d, k_d, i_b]
)
+ rigid_global_info.mass_mat_L[i_d, j_d, i_b] = a
- a = rigid_global_info.mass_mat_L_bw[0, i_pr, j_pr, i_b] - sum
- b = qd.math.clamp(
- rigid_global_info.mass_mat_L_bw[1, j_pr, j_pr, i_b],
- EPS,
- qd.math.inf,
+ # FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them.
+ rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0
+ else:
+ BLOCK_DIM = qd.static(32)
+ MAX_DOFS_PER_ENTITY = qd.static(static_rigid_sim_config.tiled_n_dofs_per_entity)
+ WARP_SIZE = qd.static(32)
+
+ qd.loop_config(name="factor_mass", block_dim=BLOCK_DIM)
+ for i in range(n_entities * _B * BLOCK_DIM):
+ tid = i % BLOCK_DIM
+ i_e = (i // BLOCK_DIM) % n_entities
+ i_b = i // (BLOCK_DIM * n_entities)
+ if i_b >= _B:
+ continue
+
+ if rigid_global_info.mass_mat_mask[i_e, i_b]:
+ entity_dof_start = entities_info.dof_start[i_e]
+ entity_dof_end = entities_info.dof_end[i_e]
+ n_dofs = entities_info.n_dofs[i_e]
+ n_lower_tri = n_dofs * (n_dofs + 1) // 2
+
+ mass_mat = qd.simt.block.SharedArray((MAX_DOFS_PER_ENTITY, MAX_DOFS_PER_ENTITY + 1), gs.qd_float)
+
+ i_pair = tid
+ while i_pair < n_lower_tri:
+ i_d_ = qd.cast((qd.sqrt(8 * i_pair + 1) - 1) // 2, qd.i32)
+ j_d_ = i_pair - i_d_ * (i_d_ + 1) // 2
+ i_d = entity_dof_start + i_d_
+ j_d = entity_dof_start + j_d_
+ mass_mat[i_d_, j_d_] = rigid_global_info.mass_mat[i_d, j_d, i_b]
+ i_pair = i_pair + BLOCK_DIM
+ qd.simt.block.sync()
+
+ if qd.static(implicit_damping):
+ i_d_ = tid
+ while i_d_ < n_dofs:
+ i_d = entity_dof_start + i_d_
+ I_d = [i_d, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d
+ mass_mat[i_d_, i_d_] = (
+ mass_mat[i_d_, i_d_] + dofs_info.damping[I_d] * rigid_global_info.substep_dt[None]
)
- if p_i0 == p_j0:
- rigid_global_info.mass_mat_L_bw[1, i_pr, j_pr, i_b] = qd.sqrt(
- qd.math.clamp(a, EPS, qd.math.inf)
- )
- else:
- rigid_global_info.mass_mat_L_bw[1, i_pr, j_pr, i_b] = a / b
+ if qd.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast):
+ if dofs_state.ctrl_mode[i_d, i_b] <= gs.CTRL_MODE.VELOCITY:
+ mass_mat[i_d_, i_d_] = (
+ mass_mat[i_d_, i_d_]
+ - dofs_info.act_bias[I_d][2] * rigid_global_info.substep_dt[None]
+ )
+ i_d_ = i_d_ + BLOCK_DIM
+ qd.simt.block.sync()
- for i_d0 in range(n_dofs):
- for i_d1 in range(i_d0 + 1):
- i_d = entity_dof_start + i_d0
- j_d = entity_dof_start + i_d1
- i_pr = (entity_dof_start + entity_dof_end - 1) - i_d
- j_pr = (entity_dof_start + entity_dof_end - 1) - j_d
+ for j in range(n_dofs):
+ i_d_ = n_dofs - j - 1
+ i_d = entity_dof_end - j - 1
- a = rigid_global_info.mass_mat_L_bw[1, i_pr, i_pr, i_b]
- rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat_L_bw[
- 1, j_pr, i_pr, i_b
- ] / qd.math.clamp(a, EPS, qd.math.inf)
+ D_inv = 1.0 / mass_mat[i_d_, i_d_]
+ if tid == 0:
+ rigid_global_info.mass_mat_D_inv[i_d, i_b] = D_inv
+ # FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them.
+ rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0
+
+ j_d_ = i_d_ - 1 - tid
+ while j_d_ >= 0:
+ a = mass_mat[i_d_, j_d_] * D_inv
+ for k_d in range(j_d_ + 1):
+ mass_mat[j_d_, k_d] = mass_mat[j_d_, k_d] - a * mass_mat[i_d_, k_d]
+ mass_mat[i_d_, j_d_] = a
+ j_d_ = j_d_ - BLOCK_DIM
+ if qd.static(static_rigid_sim_config.backend == gs.cuda):
+ if i_d_ <= WARP_SIZE:
+ qd.simt.warp.sync(qd.u32(0xFFFFFFFF))
+ else:
+ qd.simt.block.sync()
+ else:
+ qd.simt.block.sync()
- if i_d == j_d:
- rigid_global_info.mass_mat_D_inv[i_d, i_b] = 1.0 / (qd.math.clamp(a**2, EPS, qd.math.inf))
+ i_pair = tid
+ n_strict_lower_tri = n_dofs * (n_dofs - 1) // 2
+ while i_pair < n_strict_lower_tri:
+ i_d_ = qd.cast((qd.sqrt(8 * i_pair + 1) + 1) // 2, qd.i32)
+ j_d_ = i_pair - i_d_ * (i_d_ - 1) // 2
+ i_d = entity_dof_start + i_d_
+ j_d = entity_dof_start + j_d_
+ rigid_global_info.mass_mat_L[i_d, j_d, i_b] = mass_mat[i_d_, j_d_]
+ i_pair = i_pair + BLOCK_DIM
@qd.func
@@ -706,55 +613,36 @@ def func_solve_mass_entity(
i_b: qd.int32,
vec: qd.Tensor,
out: qd.Tensor,
- out_bw: qd.template(),
entities_info: array_class.EntitiesInfo,
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: qd.template(),
- is_backward: qd.template(),
):
- BW = qd.static(is_backward)
-
if rigid_global_info.mass_mat_mask[i_e, i_b]:
entity_dof_start = entities_info.dof_start[i_e]
entity_dof_end = entities_info.dof_end[i_e]
n_dofs = entities_info.n_dofs[i_e]
- # Step 1: Solve w st. L^T @ w = y
+ # Step 1: Solve w st. L^T @ w = y. Cross-iter same-buffer read of
+ # `out[j_d]` (j > i) is fine for forward execution — the entries it
+ # reads were finalized in earlier (larger-i_d) iterations. The
+ # reverse of this kernel is never auto-generated; the backward is
+ # handled by `kernel_manual_compute_qacc_bw` via IFT.
for i_d_ in range(n_dofs):
i_d = entity_dof_end - i_d_ - 1
curr_out = vec[i_d, i_b]
- if qd.static(BW):
- out_bw[0, i_d, i_b] = vec[i_d, i_b]
-
for j_d in range(i_d + 1, entity_dof_end):
- # Since we read out[j_d, i_b], and j_d > i_d, which means that out[j_d, i_b] is already
- # finalized at this point, we don't need to care about AD mutation rule.
- if qd.static(BW):
- out_bw[0, i_d, i_b] = (
- out_bw[0, i_d, i_b] - rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out_bw[0, j_d, i_b]
- )
- else:
- curr_out = curr_out - rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out[j_d, i_b]
-
- if qd.static(not BW):
- out[i_d, i_b] = curr_out
+ curr_out = curr_out - rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out[j_d, i_b]
+ out[i_d, i_b] = curr_out
# Step 2: z = D^{-1} w
for i_d in range(entity_dof_start, entity_dof_end):
- if qd.static(BW):
- out_bw[1, i_d, i_b] = out_bw[0, i_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b]
- else:
- out[i_d, i_b] = out[i_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b]
+ out[i_d, i_b] = out[i_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b]
# Step 3: Solve x st. L @ x = z
for i_d in range(entity_dof_start, entity_dof_end):
curr_out = out[i_d, i_b]
- if qd.static(BW):
- curr_out = out_bw[1, i_d, i_b]
-
for j_d in range(entity_dof_start, i_d):
curr_out = curr_out - rigid_global_info.mass_mat_L[i_d, j_d, i_b] * out[j_d, i_b]
-
out[i_d, i_b] = curr_out
@@ -763,15 +651,10 @@ def func_solve_mass_batch(
i_b: qd.int32,
vec: qd.Tensor,
out: qd.Tensor,
- out_bw: qd.template(),
entities_info: array_class.EntitiesInfo,
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: qd.template(),
- is_backward: qd.template(),
):
- BW = qd.static(is_backward)
-
- # This loop is considered an inner loop
qd.loop_config(serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL))
for i_0 in (
range(rigid_global_info.n_awake_entities[i_b])
@@ -779,27 +662,20 @@ def func_solve_mass_batch(
else range(entities_info.n_links.shape[0])
):
i_e = rigid_global_info.awake_entities[i_0, i_b] if qd.static(static_rigid_sim_config.use_hibernation) else i_0
- func_solve_mass_entity(
- i_e, i_b, vec, out, out_bw, entities_info, rigid_global_info, static_rigid_sim_config, is_backward
- )
+ func_solve_mass_entity(i_e, i_b, vec, out, entities_info, rigid_global_info, static_rigid_sim_config)
@qd.func
def func_solve_mass(
vec: qd.Tensor,
out: qd.Tensor,
- out_bw: qd.template(), # None in forward mode, real tensor in backward mode
entities_info: array_class.EntitiesInfo,
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: qd.template(),
- is_backward: qd.template(),
):
- # This loop must be the outermost loop to be differentiable
qd.loop_config(serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL))
for i_e, i_b in qd.ndrange(entities_info.n_links.shape[0], out.shape[1]):
- func_solve_mass_entity(
- i_e, i_b, vec, out, out_bw, entities_info, rigid_global_info, static_rigid_sim_config, is_backward
- )
+ func_solve_mass_entity(i_e, i_b, vec, out, entities_info, rigid_global_info, static_rigid_sim_config)
@qd.func
@@ -1257,11 +1133,9 @@ def func_compute_qacc(
func_solve_mass(
vec=dofs_state.force,
out=dofs_state.acc_smooth,
- out_bw=dofs_state.acc_smooth_bw,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
- is_backward=is_backward,
)
# Assume this is the outermost loop
@@ -1371,11 +1245,7 @@ def func_integrate(
dofs_state.vel_next[dof_start + 2, i_b],
]
)
- # Backward pass requires atomic add
- if qd.static(BW):
- qd.atomic_add(pos, vel * rigid_global_info.substep_dt[None])
- else:
- pos = pos + vel * rigid_global_info.substep_dt[None]
+ pos += vel * rigid_global_info.substep_dt[None]
for j in qd.static(range(3)):
rigid_global_info.qpos_next[q_start + j, i_b] = pos[j]
if joint_type == gs.JOINT_TYPE.SPHERICAL or joint_type == gs.JOINT_TYPE.FREE:
@@ -1427,6 +1297,10 @@ def kernel_forward_dynamics_without_qacc(
contact_island_state: array_class.ContactIslandState,
is_backward: qd.template(),
):
+ # Backward-only kernel. `func_factor_mass` is omitted: its reverse is
+ # unneeded since `kernel_manual_compute_qacc_bw` seeds `mass_mat.grad`
+ # directly via IFT (skipping the LDLT factor chain). `func_compute_mass_matrix`
+ # is kept so Quadrants AD auto-reverses `mass_mat -> links_state.{pos,quat}`.
func_compute_mass_matrix(
implicit_damping=qd.static(static_rigid_sim_config.integrator == gs.integrator.approximate_implicitfast),
links_state=links_state,
@@ -1438,15 +1312,6 @@ def kernel_forward_dynamics_without_qacc(
static_rigid_sim_config=static_rigid_sim_config,
is_backward=is_backward,
)
- func_factor_mass(
- implicit_damping=False,
- entities_info=entities_info,
- dofs_state=dofs_state,
- dofs_info=dofs_info,
- rigid_global_info=rigid_global_info,
- static_rigid_sim_config=static_rigid_sim_config,
- is_backward=is_backward,
- )
func_torque_and_passive_force(
entities_state=entities_state,
entities_info=entities_info,
@@ -1539,16 +1404,13 @@ def func_implicit_damping(
dofs_info=dofs_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
- is_backward=is_backward,
)
func_solve_mass(
vec=dofs_state.force,
out=dofs_state.acc,
- out_bw=dofs_state.acc_bw,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
- is_backward=is_backward,
)
# Disable pre-computed factorization mask right away
diff --git a/genesis/engine/solvers/rigid/abd/forward_kinematics.py b/genesis/engine/solvers/rigid/abd/forward_kinematics.py
index db7b17bfb8..61caf9e438 100644
--- a/genesis/engine/solvers/rigid/abd/forward_kinematics.py
+++ b/genesis/engine/solvers/rigid/abd/forward_kinematics.py
@@ -638,6 +638,12 @@ def func_forward_kinematics_entity(
+ joints_state.xaxis[i_j, i_b] * dofs_state.pos[dof_start, i_b]
)
pos = W(links_state.pos_bw, next_I, pos_, BW)
+ # Prismatic doesn't rotate the link, but the per-joint cache (quat_bw) still
+ # needs `next_I` populated — otherwise the final `R(quat_bw, I_jf, ...)` at the
+ # end of this function reads uninitialised memory in the backward-mode kernel,
+ # which surfaces as NaN gradients on qpos. Commit the unchanged quat to the
+ # next slot so the cache chain is contiguous through the loop iteration.
+ quat = W(links_state.quat_bw, next_I, quat, BW)
# Skip link pose update for fixed root links to let users manually overwrite them
I_jf = (i_l, 0 if qd.static(not BW) else n_joints, i_b)
@@ -1596,3 +1602,103 @@ def kernel_update_cartesian_space(
force_update_fixed_geoms=force_update_fixed_geoms,
is_backward=is_backward,
)
+
+
+# ---------------------------------------------------------------------------
+# Standalone forward-replay kernels for the `update_cartesian_space` (UCS)
+# sub-stages, used only in the backward pass (`substep_pre_coupling_grad`).
+#
+# The backward unroll replays each UCS sub-step with `is_backward=True` (static
+# loops) and then reverses it. The reverse is NOT uniform across sub-steps:
+# * COM-links and geom-pose updates reverse cleanly through Quadrants
+# autograd, so their replay kernel's `.grad` is called directly.
+# * FK and forward-velocity are reversed *manually* (`kernel_manual_forward_kinematics_bw`
+# / `kernel_manual_forward_velocity_bw`) because autograd silently drops
+# their gradient.
+# Splitting UCS into one standalone kernel per stage lets the backward mix
+# autograd `.grad` and manual reverses stage-by-stage.
+#
+# TODO: once every UCS sub-function reverses correctly under autograd, drop
+# these standalone kernels (and the manual reverses) and differentiate
+# `kernel_update_cartesian_space` as a whole.
+# ---------------------------------------------------------------------------
+@qd.kernel(fastcache=True)
+def kernel_forward_kinematics_replay(
+    envs_idx: qd.types.ndarray(),
+    links_state: array_class.LinksState,
+    links_info: array_class.LinksInfo,
+    joints_state: array_class.JointsState,
+    joints_info: array_class.JointsInfo,
+    dofs_state: array_class.DofsState,
+    dofs_info: array_class.DofsInfo,
+    entities_info: array_class.EntitiesInfo,
+    rigid_global_info: array_class.RigidGlobalInfo,
+    static_rigid_sim_config: qd.template(),
+    is_backward: qd.template(),
+):
+    """Replay forward kinematics for every environment listed in `envs_idx`.
+
+    Thin kernel wrapper: for each entry of `envs_idx` it calls
+    `func_forward_kinematics_batch` with that batch index and forwards all
+    other arguments unchanged.
+    """
+    for i_env in range(envs_idx.shape[0]):
+        # `envs_idx` stores the batch indices to process; cast to i32 before indexing with it.
+        batch_idx = qd.cast(envs_idx[i_env], qd.i32)
+        func_forward_kinematics_batch(
+            i_b=batch_idx,
+            links_state=links_state,
+            links_info=links_info,
+            joints_state=joints_state,
+            joints_info=joints_info,
+            dofs_state=dofs_state,
+            dofs_info=dofs_info,
+            entities_info=entities_info,
+            rigid_global_info=rigid_global_info,
+            static_rigid_sim_config=static_rigid_sim_config,
+            is_backward=is_backward,
+        )
+
+
+@qd.kernel(fastcache=True)
+def kernel_update_geoms_replay(
+    entities_info: array_class.EntitiesInfo,
+    geoms_state: array_class.GeomsState,
+    geoms_info: array_class.GeomsInfo,
+    links_state: array_class.LinksState,
+    rigid_global_info: array_class.RigidGlobalInfo,
+    static_rigid_sim_config: qd.template(),
+    is_backward: qd.template(),
+):
+    """Standalone kernel wrapper around `func_update_geoms`.
+
+    Forwards its arguments to `func_update_geoms` with
+    `force_update_fixed_geoms=False`; any iteration over geoms / environments
+    happens inside that function.
+    """
+    func_update_geoms(
+        entities_info,
+        geoms_state,
+        geoms_info,
+        links_state,
+        rigid_global_info,
+        static_rigid_sim_config,
+        # Fixed geoms are never force-refreshed by this replay kernel.
+        force_update_fixed_geoms=False,
+        is_backward=is_backward,
+    )
+
+
+@qd.kernel(fastcache=True)
+def kernel_COM_links_replay(
+    links_state: array_class.LinksState,
+    links_info: array_class.LinksInfo,
+    joints_state: array_class.JointsState,
+    joints_info: array_class.JointsInfo,
+    dofs_state: array_class.DofsState,
+    dofs_info: array_class.DofsInfo,
+    entities_info: array_class.EntitiesInfo,
+    rigid_global_info: array_class.RigidGlobalInfo,
+    static_rigid_sim_config: qd.template(),
+    is_backward: qd.template(),
+):
+    """Replay `func_COM_links` for every batch index.
+
+    The batch count is taken from the second axis of `links_state.pos`; every
+    other argument is passed through to `func_COM_links` unchanged.
+    """
+    for i_env in range(links_state.pos.shape[1]):
+        func_COM_links(
+            i_b=i_env,
+            links_state=links_state,
+            links_info=links_info,
+            joints_state=joints_state,
+            joints_info=joints_info,
+            dofs_state=dofs_state,
+            dofs_info=dofs_info,
+            entities_info=entities_info,
+            rigid_global_info=rigid_global_info,
+            static_rigid_sim_config=static_rigid_sim_config,
+            is_backward=is_backward,
+        )
diff --git a/genesis/engine/solvers/rigid/abd/manual_bw.py b/genesis/engine/solvers/rigid/abd/manual_bw.py
new file mode 100644
index 0000000000..31f9ff97c5
--- /dev/null
+++ b/genesis/engine/solvers/rigid/abd/manual_bw.py
@@ -0,0 +1,769 @@
+"""Manual reverse-mode kernels for the rigid backward pass.
+
+Where Quadrants AD silently drops the reverse chain, we compute
+the Jacobian-transpose by hand instead.
+
+Kernels:
+ - kernel_manual_forward_kinematics_bw : forward-kinematics reverse
+ (link pos/quat grad -> qpos / dofs_vel grad).
+ - kernel_manual_forward_velocity_bw : link-velocity propagation reverse.
+ - kernel_manual_compute_qacc_bw : reverse of `acc = M^{-1} force` via the
+ Implicit Function Theorem (writes force.grad + mass_mat.grad).
+
+Hibernation is not supported and sets the `errno` field
+(`ErrorCode.MANUAL_BW_UNIMPLEMENTED`) rather than silently corrupting
+gradients.
+
+The `@qd.func` helpers below (`d_transform_by_quat__dq`, `d_quat_mul__dlhs` /
+`d_quat_mul__drhs`, `d_rotvec_to_quat__drotvec`, `d_motion_cross_motion`) are
+the hand-written chain-rule derivatives of the corresponding
+`genesis/utils/geom.py` forward functions, used as building blocks by the kernels listed above.
+"""
+
+import quadrants as qd
+
+import genesis as gs
+import genesis.utils.array_class as array_class
+import genesis.utils.geom as gu
+
+
+@qd.func
+def d_transform_by_quat__dq(v, quat, out_grad):
+    """Gradient w.r.t. `quat` of `qd_transform_by_quat(v, quat)`.
+
+    Forward (geom.py:294):
+        out[0] = v0·(qw² + qx² - qy² - qz²) + v1·(2qxy - 2qwz) + v2·(2qxz + 2qwy)
+        out[1] = v0·(2qxy + 2qwz) + v1·(qw² - qx² + qy² - qz²) + v2·(2qyz - 2qwx)
+        out[2] = v0·(2qxz - 2qwy) + v1·(2qyz + 2qwx) + v2·(qw² - qx² - qy² + qz²)
+
+    Returns Vec4 = (∂L/∂qw, ∂L/∂qx, ∂L/∂qy, ∂L/∂qz) where
+    L is whatever scalar seeded `out_grad`. (No normalization assumed.)
+
+    Args:
+        v: 3-vector being rotated (indexed 0..2).
+        quat: quaternion in (w, x, y, z) order (indexed 0..3).
+        out_grad: upstream gradient on the rotated 3-vector (indexed 0..2).
+    """
+    qw = quat[0]
+    qx = quat[1]
+    qy = quat[2]
+    qz = quat[3]
+    v0 = v[0]
+    v1 = v[1]
+    v2 = v[2]
+    og0 = out_grad[0]
+    og1 = out_grad[1]
+    og2 = out_grad[2]
+
+    # Each block below is one row of the 3x4 Jacobian ∂out/∂quat, obtained by
+    # differentiating the forward expression term by term.
+
+    # ∂out[0]/∂{w,x,y,z}
+    do0_dqw = 2.0 * (qw * v0 - qz * v1 + qy * v2)
+    do0_dqx = 2.0 * (qx * v0 + qy * v1 + qz * v2)
+    do0_dqy = 2.0 * (-qy * v0 + qx * v1 + qw * v2)
+    do0_dqz = 2.0 * (-qz * v0 - qw * v1 + qx * v2)
+
+    # ∂out[1]/∂{w,x,y,z}
+    do1_dqw = 2.0 * (qz * v0 + qw * v1 - qx * v2)
+    do1_dqx = 2.0 * (qy * v0 - qx * v1 - qw * v2)
+    do1_dqy = 2.0 * (qx * v0 + qy * v1 + qz * v2)
+    do1_dqz = 2.0 * (qw * v0 - qz * v1 + qy * v2)
+
+    # ∂out[2]/∂{w,x,y,z}
+    do2_dqw = 2.0 * (-qy * v0 + qx * v1 + qw * v2)
+    do2_dqx = 2.0 * (qz * v0 + qw * v1 - qx * v2)
+    do2_dqy = 2.0 * (-qw * v0 + qz * v1 - qy * v2)
+    do2_dqz = 2.0 * (qx * v0 + qy * v1 + qz * v2)
+
+    # Jacobian-transpose times the upstream gradient: one entry per quaternion component.
+    return qd.Vector(
+        [
+            og0 * do0_dqw + og1 * do1_dqw + og2 * do2_dqw,
+            og0 * do0_dqx + og1 * do1_dqx + og2 * do2_dqx,
+            og0 * do0_dqy + og1 * do1_dqy + og2 * do2_dqy,
+            og0 * do0_dqz + og1 * do1_dqz + og2 * do2_dqz,
+        ],
+        dt=gs.qd_float,
+    )
+
+
+@qd.func
+def d_quat_mul__dlhs(a, b, out_grad):
+    """Gradient w.r.t. the left operand `a` of `quat_mul(a, b)` (Hamilton convention).
+
+    Forward (geom.py qd_quat_mul), components in (w, x, y, z) order:
+        out_w = aw·bw - ax·bx - ay·by - az·bz
+        out_x = aw·bx + ax·bw + ay·bz - az·by
+        out_y = aw·by - ax·bz + ay·bw + az·bx
+        out_z = aw·bz + ax·by - ay·bx + az·bw
+
+    The product is linear in `a`, so the result depends only on `b` and on the
+    upstream gradient `out_grad`; `a` itself is not read.
+    """
+    b0 = b[0]
+    b1 = b[1]
+    b2 = b[2]
+    b3 = b[3]
+    g0 = out_grad[0]
+    g1 = out_grad[1]
+    g2 = out_grad[2]
+    g3 = out_grad[3]
+
+    # One entry per component of `a`: collect every forward term it appears in.
+    grad_aw = g0 * b0 + g1 * b1 + g2 * b2 + g3 * b3
+    grad_ax = -g0 * b1 + g1 * b0 - g2 * b3 + g3 * b2
+    grad_ay = -g0 * b2 + g1 * b3 + g2 * b0 - g3 * b1
+    grad_az = -g0 * b3 - g1 * b2 + g2 * b1 + g3 * b0
+    return qd.Vector([grad_aw, grad_ax, grad_ay, grad_az], dt=gs.qd_float)
+
+
+@qd.func
+def d_quat_mul__drhs(a, b, out_grad):
+    """Gradient w.r.t. the right operand `b` of `quat_mul(a, b)`.
+
+    The product is linear in `b`, so the result depends only on `a` and on the
+    upstream gradient `out_grad`; `b` itself is not read. Components are in
+    (w, x, y, z) order.
+    """
+    a0 = a[0]
+    a1 = a[1]
+    a2 = a[2]
+    a3 = a[3]
+    g0 = out_grad[0]
+    g1 = out_grad[1]
+    g2 = out_grad[2]
+    g3 = out_grad[3]
+
+    # One entry per component of `b`: collect every forward term it appears in.
+    grad_bw = g0 * a0 + g1 * a1 + g2 * a2 + g3 * a3
+    grad_bx = -g0 * a1 + g1 * a0 + g2 * a3 - g3 * a2
+    grad_by = -g0 * a2 - g1 * a3 + g2 * a0 + g3 * a1
+    grad_bz = -g0 * a3 + g1 * a2 - g2 * a1 + g3 * a0
+    return qd.Vector([grad_bw, grad_bx, grad_by, grad_bz], dt=gs.qd_float)
+
+
+@qd.func
+def d_rotvec_to_quat__drotvec(rotvec, eps, quat_grad):
+    """Gradient w.r.t. `rotvec` of `qd_rotvec_to_quat(rotvec, eps)`.
+
+    Forward:
+        thetasq = rx² + ry² + rz²
+        theta_reg = sqrt(thetasq + eps²)
+        c = cos(theta_reg / 2)
+        sinc = sin(theta_reg / 2) / theta_reg
+        quat = (c, sinc·rx, sinc·ry, sinc·rz)
+
+    Backward — by chain rule on theta_reg(rx, ry, rz):
+        ∂theta_reg/∂ri = ri / theta_reg
+        ∂c/∂ri = -0.5·sin(theta_reg/2)·ri/theta_reg
+               = -0.5·(sin·ri)/theta_reg
+        ∂sinc/∂ri = [(0.5·cos(theta_reg/2))/theta_reg
+                     - sin(theta_reg/2)/theta_reg²] · ri/theta_reg
+                  = ri·(0.5·c/theta_reg² - sinc/theta_reg²)
+
+        ∂quat[0]/∂ri = ∂c/∂ri = -0.5·sin·ri/theta_reg
+        ∂quat[1+j]/∂ri = ∂(sinc·r_j)/∂ri
+                       = δ(i,j)·sinc + r_j·∂sinc/∂ri
+
+    So rotvec_grad[i] = quat_grad[0]·(-0.5·sin·ri/theta_reg)
+                      + sum_j quat_grad[1+j] · [δ(i,j)·sinc + r_j·∂sinc/∂ri]
+                     = quat_grad[0]·(-0.5·sin·ri/theta_reg)
+                      + sinc·quat_grad[1+i]
+                      + ∂sinc/∂ri · sum_j quat_grad[1+j]·r_j
+
+    Args:
+        rotvec: rotation vector (indexed 0..2).
+        eps: regularizer added in quadrature to the rotation angle.
+        quat_grad: upstream gradient on the output quaternion, (w, x, y, z) order.
+
+    Returns:
+        3-vector with the gradient on `rotvec`.
+    """
+    rx = rotvec[0]
+    ry = rotvec[1]
+    rz = rotvec[2]
+    thetasq = rx * rx + ry * ry + rz * rz
+    # Every division below is by theta_reg, so it must be non-zero: either
+    # eps != 0 or rotvec != 0.
+    theta_reg = qd.sqrt(thetasq + eps * eps)
+    theta_half = 0.5 * theta_reg
+    sin_h = qd.sin(theta_half)
+    cos_h = qd.cos(theta_half)
+    sinc = sin_h / theta_reg
+    # ∂sinc/∂theta_reg = (0.5·cos_h - sinc) / theta_reg
+    dsinc_dtheta = (0.5 * cos_h - sinc) / theta_reg
+
+    qg_w = quat_grad[0]
+    qg_x = quat_grad[1]
+    qg_y = quat_grad[2]
+    qg_z = quat_grad[3]
+
+    # sum_j quat_grad[1+j] · r_j
+    qg_dot_r = qg_x * rx + qg_y * ry + qg_z * rz
+
+    # ∂quat[0]/∂ri = -0.5·sin_h·ri/theta_reg
+    # ∂(sinc·rj)/∂ri = δij·sinc + r_j·(dsinc_dtheta · ri/theta_reg)
+    # so total per i:
+    #   ri·[ -0.5·sin_h/theta_reg · qg_w + dsinc_dtheta/theta_reg · qg_dot_r ] + sinc·qg_{x,y,z}[i]
+    coeff = -0.5 * sin_h / theta_reg * qg_w + dsinc_dtheta / theta_reg * qg_dot_r
+    return qd.Vector(
+        [
+            coeff * rx + sinc * qg_x,
+            coeff * ry + sinc * qg_y,
+            coeff * rz + sinc * qg_z,
+        ],
+        dt=gs.qd_float,
+    )
+
+
+@qd.kernel(fastcache=True)
+def kernel_manual_forward_kinematics_bw(
+    links_state: array_class.LinksState,
+    links_info: array_class.LinksInfo,
+    joints_state: array_class.JointsState,
+    joints_info: array_class.JointsInfo,
+    dofs_info: array_class.DofsInfo,
+    entities_info: array_class.EntitiesInfo,
+    rigid_global_info: array_class.RigidGlobalInfo,
+    static_rigid_sim_config: qd.template(),
+    errno: qd.Tensor,
+):
+    """Single-call manual reverse of `kernel_forward_kinematics_replay`.
+    Iterates each entity's links leaf->root in one launch (so a child's
+    `parent.{pos,quat}.grad` write lands before the parent consumes it), and
+    within each link reverses the *full joint chain*.
+
+    A link may carry more than one joint (e.g. a planar floating base =
+    slide-x + slide-z + hinge-y on one body). The forward composes all of them
+    in sequence and caches the per-joint intermediate pose in
+    `links_state.{pos,quat}_bw[i_l, k]`: slot 0 = the "arm base" (parent pose
+    composed with the link's fixed offset), slot k+1 = pose after joint k, slot
+    n_joints = the final link pose. We walk those slots in reverse: seed the
+    grad on the final pose, reverse joint k for k = n_joints-1 .. 0 (each step
+    consumes the grad on slot k+1, emits qpos.grad for that joint and the grad
+    on slot k), then reverse the arm-base composition (slot 0) into the parent's
+    pose grad.
+
+    Each joint also feeds `joints_state.{xanchor,xaxis}` downstream (velocity
+    FK), so we fold those accumulated grads back through slot k as well.
+
+    Joint types: FREE / REVOLUTE / PRISMATIC / SPHERICAL / FIXED.
+
+    Reads (seeds): `links_state.{pos,quat}.grad`, `joints_state.{xanchor,xaxis}.grad`.
+    Accumulates into: `rigid_global_info.qpos.grad` and the parent link's
+    `links_state.{pos,quat}.grad`.
+    Consumes (zeroes): this link's `links_state.{pos,quat}.grad` and each
+    joint's `joints_state.{xanchor,xaxis}.grad`.
+
+    NOTE(review): `errno` is accepted but never written by this kernel, and
+    there is no `use_hibernation` guard here (unlike
+    `kernel_manual_forward_velocity_bw`) — confirm this is intended.
+    """
+    qd.loop_config(
+        name="manual_fk_only_bw",
+        serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL),
+    )
+    for i_e, i_b in qd.ndrange(entities_info.n_links.shape[0], links_state.pos.shape[1]):
+        n_in_e = entities_info.n_links[i_e]
+        for i_l_rev in range(n_in_e):
+            i_l = entities_info.link_end[i_e] - 1 - i_l_rev
+            I_l = [i_l, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_l
+            parent_idx = links_info.parent_idx[I_l]
+            n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l]
+
+            # Grad seeded on the final link pose (= slot n_joints). Carried
+            # backward through the joint chain; after the loop it holds the grad
+            # on slot 0 (arm base).
+            g_pos = links_state.pos.grad[i_l, i_b]
+            g_quat = links_state.quat.grad[i_l, i_b]
+
+            for k_rev in range(n_joints):
+                k = n_joints - 1 - k_rev
+                i_j = links_info.joint_start[I_l] + k
+                I_j = [i_j, i_b] if qd.static(static_rigid_sim_config.batch_joints_info) else i_j
+                joint_type = joints_info.type[I_j]
+                q_start = joints_info.q_start[I_j]
+                dof_start = joints_info.dof_start[I_j]
+                I_d = [dof_start, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else dof_start
+
+                # Input pose to joint k (slot k), cached by the forward replay.
+                pos_in = links_state.pos_bw[i_l, k, i_b]
+                quat_in = links_state.quat_bw[i_l, k, i_b]
+                joint_pos_off = joints_info.pos[I_j]
+                xanchor_grad = joints_state.xanchor.grad[i_j, i_b]
+                xaxis_grad = joints_state.xaxis.grad[i_j, i_b]
+
+                if joint_type == gs.JOINT_TYPE.FREE:
+                    # Final pose is set absolutely from qpos (slot in unused);
+                    # xanchor = qpos[0:3].
+                    for j in qd.static(range(3)):
+                        rigid_global_info.qpos.grad[q_start + j, i_b] = (
+                            rigid_global_info.qpos.grad[q_start + j, i_b] + g_pos[j] + xanchor_grad[j]
+                        )
+                    for j in qd.static(range(4)):
+                        rigid_global_info.qpos.grad[q_start + 3 + j, i_b] = (
+                            rigid_global_info.qpos.grad[q_start + 3 + j, i_b] + g_quat[j]
+                        )
+                    # Nothing flows to the input slot: the pose came entirely from qpos.
+                    g_pos = qd.Vector([0.0, 0.0, 0.0], dt=gs.qd_float)
+                    g_quat = qd.Vector([0.0, 0.0, 0.0, 0.0], dt=gs.qd_float)
+
+                elif joint_type == gs.JOINT_TYPE.REVOLUTE:
+                    axis = dofs_info.motion_ang[I_d]
+                    angle = rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b]
+                    rotvec = axis * angle
+                    qloc = gu.qd_rotvec_to_quat(rotvec, rigid_global_info.EPS[None])
+                    # quat_out = transform_quat_by_quat(qloc, quat_in) = quat_mul(quat_in, qloc)
+                    quat_out = gu.qd_transform_quat_by_quat(qloc, quat_in)
+
+                    # pos_out = xanchor - transform(joint_pos_off, quat_out)
+                    # xanchor = transform(joint_pos_off, quat_in) + pos_in
+                    gq_out = g_quat - d_transform_by_quat__dq(joint_pos_off, quat_out, g_pos)
+                    g_qloc = d_quat_mul__drhs(quat_in, qloc, gq_out)
+                    g_quat_in_apply = d_quat_mul__dlhs(quat_in, qloc, gq_out)
+                    rotvec_grad = d_rotvec_to_quat__drotvec(rotvec, rigid_global_info.EPS[None], g_qloc)
+                    angle_grad = axis[0] * rotvec_grad[0] + axis[1] * rotvec_grad[1] + axis[2] * rotvec_grad[2]
+                    rigid_global_info.qpos.grad[q_start, i_b] = rigid_global_info.qpos.grad[q_start, i_b] + angle_grad
+
+                    # grad into xanchor = g_pos (from pos_out) + downstream xanchor_grad
+                    g_xanchor = g_pos + xanchor_grad
+                    g_quat_in = (
+                        g_quat_in_apply
+                        + d_transform_by_quat__dq(joint_pos_off, quat_in, g_xanchor)
+                        + d_transform_by_quat__dq(axis, quat_in, xaxis_grad)
+                    )
+                    g_pos = g_xanchor
+                    g_quat = g_quat_in
+
+                elif joint_type == gs.JOINT_TYPE.PRISMATIC:
+                    axis = dofs_info.motion_vel[I_d]
+                    displacement = rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b]
+                    xaxis = gu.qd_transform_by_quat(axis, quat_in)
+                    # pos_out = pos_in + xaxis * displacement ; quat_out = quat_in
+                    # xanchor = transform(joint_pos_off, quat_in) + pos_in
+                    displacement_grad = xaxis[0] * g_pos[0] + xaxis[1] * g_pos[1] + xaxis[2] * g_pos[2]
+                    rigid_global_info.qpos.grad[q_start, i_b] = (
+                        rigid_global_info.qpos.grad[q_start, i_b] + displacement_grad
+                    )
+                    g_xaxis = qd.Vector(
+                        [
+                            g_pos[0] * displacement + xaxis_grad[0],
+                            g_pos[1] * displacement + xaxis_grad[1],
+                            g_pos[2] * displacement + xaxis_grad[2],
+                        ],
+                        dt=gs.qd_float,
+                    )
+                    # Unlike REVOLUTE / SPHERICAL, pos_out does not go through
+                    # xanchor here, so only the downstream `xanchor_grad` may be
+                    # pulled back through transform(joint_pos_off, quat_in).
+                    # g_pos reaches quat_in solely via xaxis (g_xaxis above).
+                    g_quat_in = (
+                        g_quat
+                        + d_transform_by_quat__dq(axis, quat_in, g_xaxis)
+                        + d_transform_by_quat__dq(joint_pos_off, quat_in, xanchor_grad)
+                    )
+                    # pos_in receives g_pos directly (pos_out) plus xanchor_grad (xanchor).
+                    g_pos = g_pos + xanchor_grad
+                    g_quat = g_quat_in
+
+                elif joint_type == gs.JOINT_TYPE.SPHERICAL:
+                    # qloc = qpos[q_start:q_start+4] (direct); quat_out = quat_mul(quat_in, qloc).
+                    # axis defaults to [0,0,1] (xaxis = transform(axis, quat_in)).
+                    axis = qd.Vector([0.0, 0.0, 1.0], dt=gs.qd_float)
+                    qloc = qd.Vector(
+                        [
+                            rigid_global_info.qpos[q_start, i_b],
+                            rigid_global_info.qpos[q_start + 1, i_b],
+                            rigid_global_info.qpos[q_start + 2, i_b],
+                            rigid_global_info.qpos[q_start + 3, i_b],
+                        ],
+                        dt=gs.qd_float,
+                    )
+                    quat_out = gu.qd_transform_quat_by_quat(qloc, quat_in)
+                    gq_out = g_quat - d_transform_by_quat__dq(joint_pos_off, quat_out, g_pos)
+                    g_qloc = d_quat_mul__drhs(quat_in, qloc, gq_out)
+                    g_quat_in_apply = d_quat_mul__dlhs(quat_in, qloc, gq_out)
+                    for j in qd.static(range(4)):
+                        rigid_global_info.qpos.grad[q_start + j, i_b] = (
+                            rigid_global_info.qpos.grad[q_start + j, i_b] + g_qloc[j]
+                        )
+                    g_xanchor = g_pos + xanchor_grad
+                    g_quat_in = (
+                        g_quat_in_apply
+                        + d_transform_by_quat__dq(joint_pos_off, quat_in, g_xanchor)
+                        + d_transform_by_quat__dq(axis, quat_in, xaxis_grad)
+                    )
+                    g_pos = g_xanchor
+                    g_quat = g_quat_in
+
+                else:  # gs.JOINT_TYPE.FIXED — pose passes through unchanged.
+                    pass
+
+                # The joint's xanchor / xaxis seeds have been folded in above: consume them.
+                for j in qd.static(range(3)):
+                    joints_state.xanchor.grad[i_j, i_b][j] = 0.0
+                    joints_state.xaxis.grad[i_j, i_b][j] = 0.0
+
+            # Reverse the arm-base composition (slot 0):
+            #   arm_base_pos  = parent_pos + transform(link_offset_pos, parent_quat)
+            #   arm_base_quat = quat_mul(parent_quat, link_offset_quat)
+            # propagating slot-0 grad (g_pos, g_quat) into the parent's pose grad.
+            if parent_idx != -1:
+                parent_quat = links_state.quat[parent_idx, i_b]
+                link_off_pos = links_info.pos[I_l]
+                link_off_quat = links_info.quat[I_l]
+                parent_quat_grad_from_pos = d_transform_by_quat__dq(link_off_pos, parent_quat, g_pos)
+                parent_quat_grad_from_quat = d_quat_mul__dlhs(parent_quat, link_off_quat, g_quat)
+                for j in qd.static(range(3)):
+                    links_state.pos.grad[parent_idx, i_b][j] = links_state.pos.grad[parent_idx, i_b][j] + g_pos[j]
+                for j in qd.static(range(4)):
+                    links_state.quat.grad[parent_idx, i_b][j] = (
+                        links_state.quat.grad[parent_idx, i_b][j]
+                        + parent_quat_grad_from_pos[j]
+                        + parent_quat_grad_from_quat[j]
+                    )
+
+            # This link's pose grad has been fully propagated: consume it.
+            for j in qd.static(range(3)):
+                links_state.pos.grad[i_l, i_b][j] = 0.0
+            for j in qd.static(range(4)):
+                links_state.quat.grad[i_l, i_b][j] = 0.0
+
+
+@qd.func
+def d_motion_cross_motion(s_ang, s_vel, m_ang, m_vel, ang_g, vel_g):
+    """Reverse of motion_cross_motion(s_ang, s_vel, m_ang, m_vel).
+
+    Forward (geom.py:437):
+        vel = s_ang × m_vel + s_vel × m_ang
+        ang = s_ang × m_ang
+
+    For a single cross product c = a × b the adjoints are
+    a.g += b × c.g and b.g += c.g × a; applying that to each forward term:
+        s_ang.g += m_ang × ang.g + m_vel × vel.g
+        s_vel.g += m_ang × vel.g
+        m_ang.g += ang.g × s_ang + vel.g × s_vel
+        m_vel.g += vel.g × s_ang
+
+    `ang_g` / `vel_g` are the upstream gradients on `ang` / `vel`.
+    Returns (s_ang_g, s_vel_g, m_ang_g, m_vel_g) as additive deltas; the caller
+    accumulates them.
+    """
+    grad_s_ang = m_ang.cross(ang_g) + m_vel.cross(vel_g)
+    grad_s_vel = m_ang.cross(vel_g)
+    grad_m_ang = ang_g.cross(s_ang) + vel_g.cross(s_vel)
+    grad_m_vel = vel_g.cross(s_ang)
+    return grad_s_ang, grad_s_vel, grad_m_ang, grad_m_vel
+
+
+@qd.kernel(fastcache=True)
+def kernel_manual_forward_velocity_bw(
+    links_state: array_class.LinksState,
+    links_info: array_class.LinksInfo,
+    joints_info: array_class.JointsInfo,
+    dofs_state: array_class.DofsState,
+    entities_info: array_class.EntitiesInfo,
+    rigid_global_info: array_class.RigidGlobalInfo,
+    static_rigid_sim_config: qd.template(),
+    errno: qd.Tensor,
+):
+    """Manual reverse of `kernel_forward_velocity` — single-call (no per-link
+    split). Replaces the diagnostic per-link split in `substep_pre_coupling_grad`
+    by computing the cross-link `cd_{vel,ang}[parent_idx]` chain explicitly.
+
+    Inputs (read .grad seeds):
+      - cd_vel.grad[i_l, i_b], cd_ang.grad[i_l, i_b]
+      - cd_vel_bw.grad[i_l, k, i_b], cd_ang_bw.grad[i_l, k, i_b]
+      - cdofd_ang.grad[i_d, i_b], cdofd_vel.grad[i_d, i_b]
+
+    Outputs (accumulated .grad):
+      - dofs_state.vel.grad[i_d, i_b]
+      - dofs_state.cdof_ang.grad[i_d, i_b], dofs_state.cdof_vel.grad[i_d, i_b]
+      - links_state.cd_vel.grad[parent_idx, i_b], links_state.cd_ang.grad[parent_idx, i_b]
+        (cross-link chain — equivalent to forward replay's BW=True
+         `cd_*_bw[i_l, 0] = parent.cd_*`)
+
+    Every seed listed under "Inputs" is zeroed once it has been propagated.
+
+    When `static_rigid_sim_config.use_hibernation` is set, no gradient is
+    computed: the kernel only ORs `ErrorCode.MANUAL_BW_UNIMPLEMENTED` into
+    `errno[i_b]`.
+
+    NOTE(review): `rigid_global_info` is not read anywhere in this kernel body.
+    """
+    qd.loop_config(
+        name="manual_forward_velocity_bw",
+        serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL),
+    )
+    for i_e, i_b in qd.ndrange(entities_info.n_links.shape[0], links_state.pos.shape[1]):
+        if qd.static(static_rigid_sim_config.use_hibernation):
+            # Unsupported configuration: flag it instead of producing gradients.
+            errno[i_b] = errno[i_b] | array_class.ErrorCode.MANUAL_BW_UNIMPLEMENTED
+        else:
+            n_in_e = entities_info.n_links[i_e]
+            # Leaf → root iteration so each link's cd_*_bw[0].grad (which
+            # accumulates into parent.cd_*.grad) is propagated *before* the
+            # parent's own iteration uses it.
+            for i_l_rev in range(n_in_e):
+                i_l = entities_info.link_end[i_e] - 1 - i_l_rev
+                I_l = [i_l, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_l
+                n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l]
+                i_p = links_info.parent_idx[I_l]
+
+                # ── Step 1 reverse: cd_*[i_l].grad → cd_*_bw[i_l, n_joints].grad
+                for k in qd.static(range(3)):
+                    links_state.cd_vel_bw.grad[i_l, n_joints, i_b][k] = (
+                        links_state.cd_vel_bw.grad[i_l, n_joints, i_b][k] + links_state.cd_vel.grad[i_l, i_b][k]
+                    )
+                    links_state.cd_ang_bw.grad[i_l, n_joints, i_b][k] = (
+                        links_state.cd_ang_bw.grad[i_l, n_joints, i_b][k] + links_state.cd_ang.grad[i_l, i_b][k]
+                    )
+                # consume cd_vel/cd_ang.grad[i_l]
+                for k in qd.static(range(3)):
+                    links_state.cd_vel.grad[i_l, i_b][k] = 0.0
+                    links_state.cd_ang.grad[i_l, i_b][k] = 0.0
+
+                # ── Step 2: iterate joints in reverse
+                for i_j_rev in range(n_joints):
+                    i_j_ = n_joints - 1 - i_j_rev
+                    i_j = i_j_ + links_info.joint_start[I_l]
+                    I_j = [i_j, i_b] if qd.static(static_rigid_sim_config.batch_joints_info) else i_j
+                    jt = joints_info.type[I_j]
+                    ds = joints_info.dof_start[I_j]
+                    de = joints_info.dof_end[I_j]
+                    # Slot indices into cd_*_bw for this joint: input (curr) and output (next).
+                    curr_idx = i_j_
+                    next_idx = i_j_ + 1
+
+                    # ── [d-rev] cd_*_bw[next].grad → cdof_*.grad / vel.grad
+                    # Forward (FREE angular: i_3=0..2 at d=ds+3+i_3; else: d in ds..de):
+                    #   _vel = cdof_vel[d] * vel[d]; atomic_add(cd_vel_bw[next], _vel)
+                    #   _ang = cdof_ang[d] * vel[d]; atomic_add(cd_ang_bw[next], _ang)
+                    cvg_next = links_state.cd_vel_bw.grad[i_l, next_idx, i_b]
+                    cag_next = links_state.cd_ang_bw.grad[i_l, next_idx, i_b]
+                    if jt == gs.JOINT_TYPE.FREE:
+                        for i_3 in qd.static(range(3)):
+                            d_i = ds + 3 + i_3
+                            v_at_d = dofs_state.vel[d_i, i_b]
+                            cdv = dofs_state.cdof_vel[d_i, i_b]
+                            cda = dofs_state.cdof_ang[d_i, i_b]
+                            for k in qd.static(range(3)):
+                                dofs_state.cdof_vel.grad[d_i, i_b][k] = (
+                                    dofs_state.cdof_vel.grad[d_i, i_b][k] + cvg_next[k] * v_at_d
+                                )
+                                dofs_state.cdof_ang.grad[d_i, i_b][k] = (
+                                    dofs_state.cdof_ang.grad[d_i, i_b][k] + cag_next[k] * v_at_d
+                                )
+                            dot_vel = cdv[0] * cvg_next[0] + cdv[1] * cvg_next[1] + cdv[2] * cvg_next[2]
+                            dot_ang = cda[0] * cag_next[0] + cda[1] * cag_next[1] + cda[2] * cag_next[2]
+                            dofs_state.vel.grad[d_i, i_b] = dofs_state.vel.grad[d_i, i_b] + dot_vel + dot_ang
+                    else:
+                        for i_d in range(ds, de):
+                            v_at_d = dofs_state.vel[i_d, i_b]
+                            cdv = dofs_state.cdof_vel[i_d, i_b]
+                            cda = dofs_state.cdof_ang[i_d, i_b]
+                            for k in qd.static(range(3)):
+                                dofs_state.cdof_vel.grad[i_d, i_b][k] = (
+                                    dofs_state.cdof_vel.grad[i_d, i_b][k] + cvg_next[k] * v_at_d
+                                )
+                                dofs_state.cdof_ang.grad[i_d, i_b][k] = (
+                                    dofs_state.cdof_ang.grad[i_d, i_b][k] + cag_next[k] * v_at_d
+                                )
+                            dot_vel = cdv[0] * cvg_next[0] + cdv[1] * cvg_next[1] + cdv[2] * cvg_next[2]
+                            dot_ang = cda[0] * cag_next[0] + cda[1] * cag_next[1] + cda[2] * cag_next[2]
+                            dofs_state.vel.grad[i_d, i_b] = dofs_state.vel.grad[i_d, i_b] + dot_vel + dot_ang
+
+                    # ── [c-rev] cd_*_bw[next] = cd_*_bw[curr] → curr.grad += next.grad
+                    for k in qd.static(range(3)):
+                        links_state.cd_vel_bw.grad[i_l, curr_idx, i_b][k] = (
+                            links_state.cd_vel_bw.grad[i_l, curr_idx, i_b][k] + cvg_next[k]
+                        )
+                        links_state.cd_ang_bw.grad[i_l, curr_idx, i_b][k] = (
+                            links_state.cd_ang_bw.grad[i_l, curr_idx, i_b][k] + cag_next[k]
+                        )
+                    # consume next
+                    for k in qd.static(range(3)):
+                        links_state.cd_vel_bw.grad[i_l, next_idx, i_b][k] = 0.0
+                        links_state.cd_ang_bw.grad[i_l, next_idx, i_b][k] = 0.0
+
+                    # ── [b-rev] motion_cross_motion reverse:
+                    # Forward: (cdofd_ang[d_i], cdofd_vel[d_i]) =
+                    #   motion_cross_motion(cd_ang_bw[curr], cd_vel_bw[curr], cdof_ang[d_i], cdof_vel[d_i])
+                    # Reverse via d_motion_cross_motion(s_ang, s_vel, m_ang, m_vel, ang_g, vel_g)
+                    s_ang_primal = links_state.cd_ang_bw[i_l, curr_idx, i_b]
+                    s_vel_primal = links_state.cd_vel_bw[i_l, curr_idx, i_b]
+                    if jt == gs.JOINT_TYPE.FREE:
+                        # Angular dofs i_3=0..2 at d_i = ds + 3 + i_3 (linear cdofd_* are explicit 0)
+                        for i_3 in qd.static(range(3)):
+                            d_i = ds + 3 + i_3
+                            ang_g = dofs_state.cdofd_ang.grad[d_i, i_b]
+                            vel_g = dofs_state.cdofd_vel.grad[d_i, i_b]
+                            cda = dofs_state.cdof_ang[d_i, i_b]
+                            cdv = dofs_state.cdof_vel[d_i, i_b]
+                            s_ang_g, s_vel_g, m_ang_g, m_vel_g = d_motion_cross_motion(
+                                s_ang_primal, s_vel_primal, cda, cdv, ang_g, vel_g
+                            )
+                            for k in qd.static(range(3)):
+                                links_state.cd_ang_bw.grad[i_l, curr_idx, i_b][k] = (
+                                    links_state.cd_ang_bw.grad[i_l, curr_idx, i_b][k] + s_ang_g[k]
+                                )
+                                links_state.cd_vel_bw.grad[i_l, curr_idx, i_b][k] = (
+                                    links_state.cd_vel_bw.grad[i_l, curr_idx, i_b][k] + s_vel_g[k]
+                                )
+                                dofs_state.cdof_ang.grad[d_i, i_b][k] = (
+                                    dofs_state.cdof_ang.grad[d_i, i_b][k] + m_ang_g[k]
+                                )
+                                dofs_state.cdof_vel.grad[d_i, i_b][k] = (
+                                    dofs_state.cdof_vel.grad[d_i, i_b][k] + m_vel_g[k]
+                                )
+                            # consume cdofd_*.grad[d_i]
+                            for k in qd.static(range(3)):
+                                dofs_state.cdofd_ang.grad[d_i, i_b][k] = 0.0
+                                dofs_state.cdofd_vel.grad[d_i, i_b][k] = 0.0
+                        # Linear dofs (i_3=0..2 at d_i = ds + i_3): cdofd_* set to 0
+                        # (constant), reverse is no-op; just consume to mirror P8.
+                        for i_3 in qd.static(range(3)):
+                            d_i = ds + i_3
+                            for k in qd.static(range(3)):
+                                dofs_state.cdofd_ang.grad[d_i, i_b][k] = 0.0
+                                dofs_state.cdofd_vel.grad[d_i, i_b][k] = 0.0
+                    else:
+                        for i_d in range(ds, de):
+                            ang_g = dofs_state.cdofd_ang.grad[i_d, i_b]
+                            vel_g = dofs_state.cdofd_vel.grad[i_d, i_b]
+                            cda = dofs_state.cdof_ang[i_d, i_b]
+                            cdv = dofs_state.cdof_vel[i_d, i_b]
+                            s_ang_g, s_vel_g, m_ang_g, m_vel_g = d_motion_cross_motion(
+                                s_ang_primal, s_vel_primal, cda, cdv, ang_g, vel_g
+                            )
+                            for k in qd.static(range(3)):
+                                links_state.cd_ang_bw.grad[i_l, curr_idx, i_b][k] = (
+                                    links_state.cd_ang_bw.grad[i_l, curr_idx, i_b][k] + s_ang_g[k]
+                                )
+                                links_state.cd_vel_bw.grad[i_l, curr_idx, i_b][k] = (
+                                    links_state.cd_vel_bw.grad[i_l, curr_idx, i_b][k] + s_vel_g[k]
+                                )
+                                dofs_state.cdof_ang.grad[i_d, i_b][k] = (
+                                    dofs_state.cdof_ang.grad[i_d, i_b][k] + m_ang_g[k]
+                                )
+                                dofs_state.cdof_vel.grad[i_d, i_b][k] = (
+                                    dofs_state.cdof_vel.grad[i_d, i_b][k] + m_vel_g[k]
+                                )
+                            for k in qd.static(range(3)):
+                                dofs_state.cdofd_ang.grad[i_d, i_b][k] = 0.0
+                                dofs_state.cdofd_vel.grad[i_d, i_b][k] = 0.0
+
+                    # ── [a-rev] (FREE only) cd_*_bw[curr].grad → linear cdof_*.grad / vel.grad
+                    # Forward (FREE linear pre-motion_cross_motion): for i_3=0..2 at d_i = ds + i_3,
+                    #   _vel = cdof_vel[d_i] * vel[d_i]; atomic_add(cd_vel_bw[curr], _vel)
+                    #   _ang = cdof_ang[d_i] * vel[d_i]; atomic_add(cd_ang_bw[curr], _ang)
+                    # (cdof_vel[linear] = e_i_3 constant; cdof_ang[linear] = 0 constant)
+                    if jt == gs.JOINT_TYPE.FREE:
+                        # Read after [c-rev] and [b-rev] so the curr-slot grad is complete.
+                        cvg_curr = links_state.cd_vel_bw.grad[i_l, curr_idx, i_b]
+                        cag_curr = links_state.cd_ang_bw.grad[i_l, curr_idx, i_b]
+                        for i_3 in qd.static(range(3)):
+                            d_i = ds + i_3
+                            v_at_d = dofs_state.vel[d_i, i_b]
+                            cdv = dofs_state.cdof_vel[d_i, i_b]
+                            cda = dofs_state.cdof_ang[d_i, i_b]
+                            for k in qd.static(range(3)):
+                                dofs_state.cdof_vel.grad[d_i, i_b][k] = (
+                                    dofs_state.cdof_vel.grad[d_i, i_b][k] + cvg_curr[k] * v_at_d
+                                )
+                                dofs_state.cdof_ang.grad[d_i, i_b][k] = (
+                                    dofs_state.cdof_ang.grad[d_i, i_b][k] + cag_curr[k] * v_at_d
+                                )
+                            dot_vel = cdv[0] * cvg_curr[0] + cdv[1] * cvg_curr[1] + cdv[2] * cvg_curr[2]
+                            dot_ang = cda[0] * cag_curr[0] + cda[1] * cag_curr[1] + cda[2] * cag_curr[2]
+                            dofs_state.vel.grad[d_i, i_b] = dofs_state.vel.grad[d_i, i_b] + dot_vel + dot_ang
+
+                # ── Step 1 (initial cvel setup) reverse:
+                # Forward: cd_*_bw[i_l, 0, i_b] = parent.cd_*[i_p, i_b] (if i_p != -1) else 0
+                # Reverse: parent.cd_*.grad[i_p] += cd_*_bw[i_l, 0].grad; consume slot 0
+                slot0_v_g = links_state.cd_vel_bw.grad[i_l, 0, i_b]
+                slot0_a_g = links_state.cd_ang_bw.grad[i_l, 0, i_b]
+                if i_p != -1:
+                    for k in qd.static(range(3)):
+                        links_state.cd_vel.grad[i_p, i_b][k] = links_state.cd_vel.grad[i_p, i_b][k] + slot0_v_g[k]
+                        links_state.cd_ang.grad[i_p, i_b][k] = links_state.cd_ang.grad[i_p, i_b][k] + slot0_a_g[k]
+                # consume slot 0
+                for k in qd.static(range(3)):
+                    links_state.cd_vel_bw.grad[i_l, 0, i_b][k] = 0.0
+                    links_state.cd_ang_bw.grad[i_l, 0, i_b][k] = 0.0
+
+
+@qd.kernel(fastcache=True)
+def kernel_manual_compute_qacc_bw(
+    dofs_state: array_class.DofsState,
+    entities_info: array_class.EntitiesInfo,
+    rigid_global_info: array_class.RigidGlobalInfo,
+    static_rigid_sim_config: qd.template(),
+):
+    """Manual backward for `func_compute_qacc` via Implicit Function Theorem.
+
+    Forward chain (`func_compute_qacc`):
+        acc_smooth = M^{-1} . force          (via LDLT solve in `func_solve_mass`)
+        acc[i] = acc_smooth[i]               (identity copy)
+
+    Reverse chain (manual, by IFT and symmetry of M = L^T D L):
+        acc_smooth.grad += acc.grad          (reverse of identity copy)
+        acc.grad = 0                         (forward overwrites acc)
+        force_contrib = M^{-1} . acc_smooth.grad
+        force.grad += force_contrib
+            (M is symmetric, so M^{-T} = M^{-1})
+        # Symmetric lower-tri storage of mass_mat: forward only reads
+        # mass_mat[i, j] for i >= j (lower-tri). Each off-diagonal parameter
+        # appears once in the forward, but the implicit symmetry means the
+        # IFT contribution to the lower-tri parameter combines both `(i, j)`
+        # and `(j, i)` chain terms (the upper mirror is logically present
+        # via symmetry of the matrix being inverted).
+        mass_mat[i, i].grad += -force_contrib[i] * acc_smooth[i]
+        mass_mat[i, j].grad += -(force_contrib[i] * acc_smooth[j]
+                                 + force_contrib[j] * acc_smooth[i])    (i > j)
+
+    LDLT structure (Genesis transposed convention M = L^T D L):
+        Step 1: solve L^T . u = acc_smooth.grad     (descending i_d)
+        Step 2: v = D^{-1} . u
+        Step 3: solve L . delta_force_grad = v      (ascending i_d)
+
+    Inputs vs outputs (note the asymmetry between the values *read* from
+    `rigid_global_info` and the grad fields *written* into it):
+        Reads:
+            - dofs_state.acc.grad, dofs_state.acc_smooth.grad       (seed)
+            - rigid_global_info.mass_mat_L                          (LDLT solve)
+            - rigid_global_info.mass_mat_D_inv                      (LDLT solve)
+            - dofs_state.acc_smooth                                 (IFT outer product)
+        Writes:
+            - dofs_state.force.grad                                 (M^{-1} . seed)
+            - rigid_global_info.mass_mat.grad                       (IFT seed)
+            - dofs_state.acc.grad, dofs_state.acc_smooth.grad       (consumed → 0)
+
+    `dofs_state.acc_smooth_bw[0]` / `[1]` are used as ping-pong scratch buffers
+    for the three solve steps; on exit `[1]` holds `force_contrib`.
+
+    Entities whose `rigid_global_info.mass_mat_mask[i_e, i_b]` is false are
+    skipped entirely (their seeds are left untouched).
+
+    The dense `mass_mat` is *not* read here (the forward already factored it
+    into `mass_mat_L` / `mass_mat_D_inv`), but its `.grad` is the parameter
+    the IFT chain naturally exposes, so this kernel is the only place
+    `mass_mat.grad` gets populated in the backward path.
+    """
+    qd.loop_config(serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL))
+    for i_e, i_b in qd.ndrange(entities_info.n_links.shape[0], dofs_state.force.shape[1]):
+        if rigid_global_info.mass_mat_mask[i_e, i_b]:
+            entity_dof_start = entities_info.dof_start[i_e]
+            entity_dof_end = entities_info.dof_end[i_e]
+            n_dofs = entities_info.n_dofs[i_e]
+
+            # Reverse of `acc[i] = acc_smooth[i]`: drain acc.grad → seed acc_smooth.grad.
+            # Stash the combined seed in `acc_smooth_bw[0]` as the input to the LDLT
+            # reverse solve. Then zero `acc.grad` since the forward copy overwrites
+            # `acc` (its old value is destroyed, so any prior `acc.grad` is consumed).
+            for i_d in range(entity_dof_start, entity_dof_end):
+                dofs_state.acc_smooth_bw[0, i_d, i_b] = (
+                    dofs_state.acc_smooth.grad[i_d, i_b] + dofs_state.acc.grad[i_d, i_b]
+                )
+                dofs_state.acc.grad[i_d, i_b] = 0.0
+                dofs_state.acc_smooth.grad[i_d, i_b] = 0.0
+
+            # Step 1: solve L^T . u = seed  (input from [0], output to [1])
+            #   u[i] = seed[i] - sum_{j>i} L[j,i] * u[j]
+            for i_d_ in range(n_dofs):
+                i_d = entity_dof_end - i_d_ - 1
+                curr = dofs_state.acc_smooth_bw[0, i_d, i_b]
+                for j_d in range(i_d + 1, entity_dof_end):
+                    curr = curr - rigid_global_info.mass_mat_L[j_d, i_d, i_b] * dofs_state.acc_smooth_bw[1, j_d, i_b]
+                dofs_state.acc_smooth_bw[1, i_d, i_b] = curr
+
+            # Step 2: v = D^{-1} . u  (output to [0], overwriting input)
+            for i_d in range(entity_dof_start, entity_dof_end):
+                dofs_state.acc_smooth_bw[0, i_d, i_b] = (
+                    dofs_state.acc_smooth_bw[1, i_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b]
+                )
+
+            # Step 3: solve L . delta = v  (input from [0], output to [1])
+            #   delta[i] = v[i] - sum_{j < i} L[i,j] * delta[j]
+            # Ascending order, so every delta[j] with j < i is already final when
+            # it is read back from [1]; the stale u values in [1] are only ever
+            # overwritten, never read, because step 2 has consumed them.
+            for i_d in range(entity_dof_start, entity_dof_end):
+                curr = dofs_state.acc_smooth_bw[0, i_d, i_b]
+                for j_d in range(entity_dof_start, i_d):
+                    curr = curr - rigid_global_info.mass_mat_L[i_d, j_d, i_b] * dofs_state.acc_smooth_bw[1, j_d, i_b]
+                dofs_state.acc_smooth_bw[1, i_d, i_b] = curr
+
+            # force.grad += force_contrib, with force_contrib = M^{-1} . seed now in [1].
+            for i_d in range(entity_dof_start, entity_dof_end):
+                dofs_state.force.grad[i_d, i_b] = (
+                    dofs_state.force.grad[i_d, i_b] + dofs_state.acc_smooth_bw[1, i_d, i_b]
+                )
+
+            # IFT seed for mass_mat.grad (lower-triangular storage):
+            #   diagonal:              ∂L/∂M_ii = -force_contrib[i] · acc_smooth[i]
+            #   off-diagonal (i > j):  ∂L/∂M_ij = -(force_contrib[i] · acc_smooth[j]
+            #                                       + force_contrib[j] · acc_smooth[i])
+            # The off-diagonal sum picks up the chain through both the
+            # (i, j) and (j, i) occurrences of the parameter in the symmetric
+            # matrix.
+            for i_d in range(entity_dof_start, entity_dof_end):
+                fi = dofs_state.acc_smooth_bw[1, i_d, i_b]
+                ai = dofs_state.acc_smooth[i_d, i_b]
+                rigid_global_info.mass_mat.grad[i_d, i_d, i_b] = (
+                    rigid_global_info.mass_mat.grad[i_d, i_d, i_b] - fi * ai
+                )
+                for j_d in range(entity_dof_start, i_d):
+                    fj = dofs_state.acc_smooth_bw[1, j_d, i_b]
+                    aj = dofs_state.acc_smooth[j_d, i_b]
+                    rigid_global_info.mass_mat.grad[i_d, j_d, i_b] = rigid_global_info.mass_mat.grad[i_d, j_d, i_b] - (
+                        fi * aj + fj * ai
+                    )
diff --git a/genesis/engine/solvers/rigid/collider/collider.py b/genesis/engine/solvers/rigid/collider/collider.py
index 6a0ad21112..9c8630bee2 100644
--- a/genesis/engine/solvers/rigid/collider/collider.py
+++ b/genesis/engine/solvers/rigid/collider/collider.py
@@ -914,6 +914,18 @@ def detection(self) -> None:
self._collider_static_config,
)
+ # Plane-convex contacts use analytic paths that don't fill
+ # `diff_contact_input`; populate it here so the differentiable
+ # narrow-phase reverse can reconstruct them (see
+ # `kernel_fill_diff_contact_input_plane`).
+ if self._solver._static_rigid_sim_config.requires_grad:
+ narrowphase.kernel_fill_diff_contact_input_plane(
+ self._solver.geoms_state,
+ self._solver.geoms_info,
+ self._solver._static_rigid_sim_config,
+ self._collider_state,
+ )
+
def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch_dim: bool = False):
# Early return if already pre-computed
contact_data = self._contact_data_cache.setdefault((as_tensor, to_torch), {})
@@ -1049,8 +1061,9 @@ def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch
def backward(self, dL_dposition, dL_dnormal, dL_dpenetration):
func_set_upstream_grad(dL_dposition, dL_dnormal, dL_dpenetration, self._collider_state)
+ self.backward_narrowphase()
- # Compute gradient
+ def backward_narrowphase(self):
func_narrow_phase_diff_convex_vs_convex.grad(
self._solver.geoms_state,
self._solver.geoms_info,
diff --git a/genesis/engine/solvers/rigid/collider/diff_gjk.py b/genesis/engine/solvers/rigid/collider/diff_gjk.py
index 336a721eff..4e0726ac8c 100644
--- a/genesis/engine/solvers/rigid/collider/diff_gjk.py
+++ b/genesis/engine/solvers/rigid/collider/diff_gjk.py
@@ -862,6 +862,57 @@ def func_differentiable_contact(
return contact_pos, contact_normal, penetration, weight
+@qd.func
+def func_differentiable_plane_contact(
+ geoms_state: array_class.GeomsState,
+ geoms_info: array_class.GeomsInfo,
+ diff_contact_input: array_class.DiffContactInput,
+ i_ga,
+ i_gb,
+ i_b,
+ i_c,
+):
+ """Differentiable plane-convex contact reconstruction.
+
+ Mirrors the analytic plane branch of `func_convex_convex_contact`:
+ normal = -normalize(R(quat_plane) @ plane_local_dir)
+ v_world = R(quat_convex) @ core_local + pos_convex + radius * normal
+ penetration = normal . (v_world - pos_plane)
+ contact_pos = v_world - 0.5 * penetration * normal
+
+ `i_ga` is the PLANE geom, `i_gb` the convex geom. `core_local` (box vertex /
+ sphere center / capsule nearest endpoint, in convex-local frame) is the
+ stored non-diff witness; `radius` and `plane_local_dir` come from geoms_info.
+ Gradients flow to both geom poses through `geoms_state.{pos,quat}`. For a
+ sphere `core_local` is the local origin, so the orientation gradient is
+ naturally zero (the contact is rotation-invariant), matching the forward.
+
+ `i_b` is the second index into `geoms_state.{pos,quat}` and, together with
+ the contact slot `i_c`, selects `diff_contact_input.core_local[i_b, i_c]`.
+
+ Returns the tuple `(contact_pos, normal, penetration, weight)`; `weight` is
+ always 1.0 on this path.
+ """
+ # Current poses of the plane (i_ga) and the convex geom (i_gb).
+ trans_plane = geoms_state.pos[i_ga, i_b]
+ quat_plane = geoms_state.quat[i_ga, i_b]
+ trans_convex = geoms_state.pos[i_gb, i_b]
+ quat_convex = geoms_state.quat[i_gb, i_b]
+
+ # The plane's local direction is the first three entries of its geom data;
+ # it is rotated by the plane quaternion, then normalized and negated.
+ plane_dir = gs.qd_vec3(geoms_info.data[i_ga][0], geoms_info.data[i_ga][1], geoms_info.data[i_ga][2])
+ plane_dir = gu.qd_transform_by_quat(plane_dir, quat_plane)
+ normal = -plane_dir.normalized()
+
+ # Only SPHERE and CAPSULE read a radius (data[0]); any other convex type
+ # keeps radius = 0, so `v_world` coincides with the stored core point.
+ radius = gs.qd_float(0.0)
+ convex_type = geoms_info.type[i_gb]
+ if convex_type == gs.GEOM_TYPE.SPHERE:
+ radius = geoms_info.data[i_gb][0]
+ elif convex_type == gs.GEOM_TYPE.CAPSULE:
+ radius = geoms_info.data[i_gb][0]
+
+ # Map the stored convex-local witness to world space with the convex pose,
+ # then offset it by `radius` along the contact normal.
+ core_local = diff_contact_input.core_local[i_b, i_c]
+ core_world = gu.qd_transform_by_trans_quat(core_local, trans_convex, quat_convex)
+ v_world = core_world + radius * normal
+
+ # Penetration is the offset of `v_world` from the plane position along
+ # `normal`; the contact position is shifted back by half of it.
+ penetration = normal.dot(v_world - trans_plane)
+ contact_pos = v_world - 0.5 * penetration * normal
+ weight = gs.qd_float(1.0)
+ return contact_pos, normal, penetration, weight
+
+
@qd.func
def func_plane_normal(v1, v2, v3):
"""
diff --git a/genesis/engine/solvers/rigid/collider/narrowphase.py b/genesis/engine/solvers/rigid/collider/narrowphase.py
index 696dcc1e61..07c3f3d122 100644
--- a/genesis/engine/solvers/rigid/collider/narrowphase.py
+++ b/genesis/engine/solvers/rigid/collider/narrowphase.py
@@ -2705,9 +2705,18 @@ def func_narrow_phase_diff_convex_vs_convex(
if is_ref:
ref_penetration = -1.0
- contact_pos, contact_normal, penetration, weight = diff_gjk.func_differentiable_contact(
- geoms_state, diff_contact_input, gjk_info, i_ga, i_gb, i_b, i_c, ref_penetration
- )
+ contact_pos = gs.qd_vec3(0.0, 0.0, 0.0)
+ contact_normal = gs.qd_vec3(0.0, 0.0, 0.0)
+ penetration = gs.qd_float(0.0)
+ weight = gs.qd_float(0.0)
+ if geoms_info.type[i_ga] == gs.GEOM_TYPE.PLANE:
+ contact_pos, contact_normal, penetration, weight = diff_gjk.func_differentiable_plane_contact(
+ geoms_state, geoms_info, diff_contact_input, i_ga, i_gb, i_b, i_c
+ )
+ else:
+ contact_pos, contact_normal, penetration, weight = diff_gjk.func_differentiable_contact(
+ geoms_state, diff_contact_input, gjk_info, i_ga, i_gb, i_b, i_c, ref_penetration
+ )
collider_state.diff_contact_input.ref_penetration[i_b, i_c] = penetration
func_set_contact(
@@ -2755,6 +2764,60 @@ def func_narrow_phase_diff_convex_vs_convex(
)
+@qd.kernel(fastcache=True)
+def kernel_fill_diff_contact_input_plane(
+ geoms_state: array_class.GeomsState,
+ geoms_info: array_class.GeomsInfo,
+ static_rigid_sim_config: qd.template(),
+ collider_state: array_class.ColliderState,
+):
+ """Populate `diff_contact_input` for plane-convex contacts (forward, non-diff).
+
+ The analytic plane paths (`func_plane_box_contact`,
+ `func_convex_convex_contact`'s plane branch) do NOT fill `diff_contact_input`,
+ so the differentiable narrow-phase reverse would have nothing to reconstruct.
+ Both paths use the same convention `contact_pos = v1 - 0.5*pen*normal` with
+ `normal = -normalize(R(quat_plane) @ plane_local_dir)`, so the convex support
+ point is recovered as `v1 = contact_pos + 0.5*pen*normal`, and its non-diff
+ "core" (vertex / sphere center / capsule endpoint) as `v1 - radius*normal`,
+ stored in the convex's local frame. PLANE is GEOM_TYPE 0 so it is always
+ `geom_a` after the canonical type-ordered swap.
+
+ Contacts whose `geom_a` is not a PLANE are left untouched by this kernel.
+ """
+ # Batch size is taken from the second dimension of `active_buffer`.
+ _B = collider_state.active_buffer.shape[1]
+ # Run the loop serially when the parallelization level is below PARTIAL.
+ qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)
+ # Visit every (contact slot, env) pair; slots at or beyond n_contacts[i_b] are skipped.
+ for i_c, i_b in qd.ndrange(collider_state.contact_data.pos.shape[0], _B):
+ if i_c < collider_state.n_contacts[i_b]:
+ i_ga = collider_state.contact_data.geom_a[i_c, i_b]
+ i_gb = collider_state.contact_data.geom_b[i_c, i_b]
+ if geoms_info.type[i_ga] == gs.GEOM_TYPE.PLANE:
+ quat_plane = geoms_state.quat[i_ga, i_b]
+ trans_convex = geoms_state.pos[i_gb, i_b]
+ quat_convex = geoms_state.quat[i_gb, i_b]
+
+ # Same normal construction as `func_differentiable_plane_contact`.
+ plane_dir = gs.qd_vec3(geoms_info.data[i_ga][0], geoms_info.data[i_ga][1], geoms_info.data[i_ga][2])
+ plane_dir = gu.qd_transform_by_quat(plane_dir, quat_plane)
+ normal = -plane_dir.normalized()
+
+ # Only SPHERE and CAPSULE read a radius (data[0]); other types keep 0.
+ radius = gs.qd_float(0.0)
+ ctype = geoms_info.type[i_gb]
+ if ctype == gs.GEOM_TYPE.SPHERE:
+ radius = geoms_info.data[i_gb][0]
+ elif ctype == gs.GEOM_TYPE.CAPSULE:
+ radius = geoms_info.data[i_gb][0]
+
+ # Invert `contact_pos = v1 - 0.5*pen*normal`, strip the radius offset,
+ # then express the core point in the convex's local frame
+ # (subtract its position, rotate by the inverse of its quaternion).
+ pen = collider_state.contact_data.penetration[i_c, i_b]
+ cpos = collider_state.contact_data.pos[i_c, i_b]
+ v1 = cpos + 0.5 * pen * normal
+ core_world = v1 - radius * normal
+ core_local = gu.qd_transform_by_quat(core_world - trans_convex, gu.qd_inv_quat(quat_convex))
+
+ # Note the index order: `diff_contact_input` is addressed [i_b, i_c]
+ # while `contact_data` is addressed [i_c, i_b]. Each plane contact
+ # gets its own slot index as `ref_id` and is flagged valid.
+ collider_state.diff_contact_input.geom_a[i_b, i_c] = i_ga
+ collider_state.diff_contact_input.geom_b[i_b, i_c] = i_gb
+ collider_state.diff_contact_input.core_local[i_b, i_c] = core_local
+ collider_state.diff_contact_input.ref_id[i_b, i_c] = i_c
+ collider_state.diff_contact_input.valid[i_b, i_c] = 1
+
+
@qd.kernel(fastcache=True)
def func_narrow_phase_convex_specializations(
geoms_state: array_class.GeomsState,
diff --git a/genesis/engine/solvers/rigid/constraint/backward.py b/genesis/engine/solvers/rigid/constraint/backward.py
index e72cac1146..9e6b3cbc76 100644
--- a/genesis/engine/solvers/rigid/constraint/backward.py
+++ b/genesis/engine/solvers/rigid/constraint/backward.py
@@ -2,6 +2,7 @@
import genesis as gs
import genesis.utils.array_class as array_class
+import genesis.utils.geom as gu
@qd.func
@@ -52,6 +53,63 @@ def func_matvec_Ap(
constraint_state.bw_Ap[i_d, i_b] += constraint_state.jac[i_c, i_d, i_b] * jv
+@qd.func
+def func_solve_adjoint_u_cg_env(
+ i_b,
+ entities_info: array_class.EntitiesInfo,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ constraint_state: array_class.ConstraintState,
+ static_rigid_sim_config: qd.template(),
+):
+ """CG solve of A u = g for a single environment `i_b`.
+
+ `A = M + J^T diag(D) J` is applied implicitly by `func_matvec_Ap`, which
+ reads `rigid_global_info.mass_mat` directly and loops only over the active
+ constraints. So this also solves the unconstrained case `A = M` (no active
+ constraint -> empty J term).
+
+ Reads the right-hand side `g` from `constraint_state.dL_dqacc[:, i_b]` and
+ leaves the solution in `constraint_state.bw_u[:, i_b]`. `bw_r`, `bw_p` and
+ `bw_Ap` are used as per-env work buffers (residual, search direction, A*p).
+ """
+ n_dofs = constraint_state.bw_u.shape[0]
+
+ # r = g - A*0 = g ; p = r ; u = 0
+ for i_d in range(n_dofs):
+ constraint_state.bw_u[i_d, i_b] = 0.0
+ constraint_state.bw_r[i_d, i_b] = constraint_state.dL_dqacc[i_d, i_b]
+ constraint_state.bw_p[i_d, i_b] = constraint_state.bw_r[i_d, i_b]
+
+ # The iteration cap is read at runtime from `rigid_global_info.iterations`.
+ for it in range(rigid_global_info.iterations[None]):
+ # Presumably fills `bw_Ap[:, i_b]` with A applied to the current `bw_p` -- confirm in `func_matvec_Ap`.
+ func_matvec_Ap(
+ entities_info=entities_info,
+ rigid_global_info=rigid_global_info,
+ constraint_state=constraint_state,
+ static_rigid_sim_config=static_rigid_sim_config,
+ i_b=i_b,
+ )
+
+ # alpha = (r,r)/(p,Ap)
+ num = gs.qd_float(0.0)
+ den = gs.qd_float(0.0)
+ for i_d in range(n_dofs):
+ num += constraint_state.bw_r[i_d, i_b] * constraint_state.bw_r[i_d, i_b]
+ den += constraint_state.bw_p[i_d, i_b] * constraint_state.bw_Ap[i_d, i_b]
+ # The denominator is floored at EPS so a vanishing (p,Ap) cannot divide by zero.
+ alpha = num / qd.max(den, rigid_global_info.EPS[None])
+
+ # u += alpha p ; r -= alpha Ap
+ for i_d in range(n_dofs):
+ constraint_state.bw_u[i_d, i_b] += alpha * constraint_state.bw_p[i_d, i_b]
+ constraint_state.bw_r[i_d, i_b] -= alpha * constraint_state.bw_Ap[i_d, i_b]
+
+ # Convergence test uses `num`, i.e. the squared residual norm computed
+ # before the update just applied, not the freshly updated residual.
+ if num < rigid_global_info.EPS[None]:
+ break
+
+ # beta = (r_new,r_new)/(r_old,r_old) ; p = r + beta p
+ num_new = gs.qd_float(0.0)
+ for i_d in range(n_dofs):
+ num_new += constraint_state.bw_r[i_d, i_b] * constraint_state.bw_r[i_d, i_b]
+ beta = num_new / qd.max(num, rigid_global_info.EPS[None])
+ for i_d in range(n_dofs):
+ constraint_state.bw_p[i_d, i_b] = constraint_state.bw_r[i_d, i_b] + beta * constraint_state.bw_p[i_d, i_b]
+
+
@qd.kernel
def kernel_solve_adjoint_u(
entities_info: array_class.EntitiesInfo,
@@ -77,76 +135,47 @@ def kernel_solve_adjoint_u(
constraint_state.bw_u[i_d, i_b] = 0.0
if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton):
- # Since we already have the Cholesky decomposition of A (= L * L^T), we can use it to solve A * u = g.
for i_b in range(_B):
- # z = L^{-1} g (forward substitution)
- # Save solution to bw_r
- for i_d in range(n_dofs):
- z = constraint_state.dL_dqacc[i_d, i_b]
- for j_d in range(i_d):
- z -= constraint_state.nt_H[i_b, i_d, j_d] * constraint_state.bw_r[j_d, i_b]
- z /= constraint_state.nt_H[i_b, i_d, i_d]
- constraint_state.bw_r[i_d, i_b] = z
-
- # u = L^{-T} z (back substitution)
- for i_d_ in range(n_dofs):
- i_d = n_dofs - 1 - i_d_
- u = constraint_state.bw_r[i_d, i_b]
- for j_d in range(i_d + 1, n_dofs):
- u -= constraint_state.nt_H[i_b, j_d, i_d] * constraint_state.bw_u[j_d, i_b]
- u /= constraint_state.nt_H[i_b, i_d, i_d]
- constraint_state.bw_u[i_d, i_b] = u
- else:
- # Use CG solver for solving A * u = g.
- # 2. Local buffers for solving A * u = g
- # Initialize r, p with dL_dqacc
- for i_d, i_b in qd.ndrange(n_dofs, _B):
- # Residual: g - A * 0 (u = 0)
- constraint_state.bw_r[i_d, i_b] = constraint_state.dL_dqacc[i_d, i_b]
- # Search direction: p = r
- constraint_state.bw_p[i_d, i_b] = constraint_state.bw_r[i_d, i_b]
-
- # 3. Solve A * u = g, parallelized over batch dimension
- for i_b in range(_B):
- # Compute Ap for the current search direction
- for it in range(static_rigid_sim_config.iterations):
- func_matvec_Ap(
+ if constraint_state.n_constraints[i_b] == 0:
+ # No active constraint: A = M. The forward's constrained-Hessian
+ # Cholesky `nt_H` is unreliable for these envs (the GPU tiled
+ # factorization skips n_c==0), so solve M u = g via CG, which
+ # reads `mass_mat` directly and never touches `nt_H`.
+ func_solve_adjoint_u_cg_env(
+ i_b=i_b,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
constraint_state=constraint_state,
static_rigid_sim_config=static_rigid_sim_config,
- i_b=i_b,
)
-
- # alpha = (r,r)/(p,Hp)
- num = gs.qd_float(0.0)
- den = gs.qd_float(0.0)
- for i_d in range(n_dofs):
- num += constraint_state.bw_r[i_d, i_b] * constraint_state.bw_r[i_d, i_b]
- den += constraint_state.bw_p[i_d, i_b] * constraint_state.bw_Ap[i_d, i_b]
- alpha = num / qd.max(den, rigid_global_info.EPS[None])
-
- # u += alpha p ; r -= alpha Hp
- for i_d in range(n_dofs):
- constraint_state.bw_u[i_d, i_b] += alpha * constraint_state.bw_p[i_d, i_b]
- constraint_state.bw_r[i_d, i_b] -= alpha * constraint_state.bw_Ap[i_d, i_b]
-
- # check tol (optional: per-batch)
- # TODO: Might need lower tolerance?
- if num < rigid_global_info.EPS[None]:
- break
-
- # beta = (r_new,r_new)/(r_old,r_old)
- num_new = gs.qd_float(0.0)
+ else:
+ # Reuse the forward's Cholesky decomposition A = L * L^T to solve A u = g.
+ # z = L^{-1} g (forward substitution); saved to bw_r
for i_d in range(n_dofs):
- num_new += constraint_state.bw_r[i_d, i_b] * constraint_state.bw_r[i_d, i_b]
- beta = num_new / qd.max(num, rigid_global_info.EPS[None])
+ z = constraint_state.dL_dqacc[i_d, i_b]
+ for j_d in range(i_d):
+ z -= constraint_state.nt_H[i_b, i_d, j_d] * constraint_state.bw_r[j_d, i_b]
+ z /= constraint_state.nt_H[i_b, i_d, i_d]
+ constraint_state.bw_r[i_d, i_b] = z
- # p = r + beta p
- for i_d in range(n_dofs):
- constraint_state.bw_p[i_d, i_b] = (
- constraint_state.bw_r[i_d, i_b] + beta * constraint_state.bw_p[i_d, i_b]
- )
+ # u = L^{-T} z (back substitution)
+ for i_d_ in range(n_dofs):
+ i_d = n_dofs - 1 - i_d_
+ u = constraint_state.bw_r[i_d, i_b]
+ for j_d in range(i_d + 1, n_dofs):
+ u -= constraint_state.nt_H[i_b, j_d, i_d] * constraint_state.bw_u[j_d, i_b]
+ u /= constraint_state.nt_H[i_b, i_d, i_d]
+ constraint_state.bw_u[i_d, i_b] = u
+ else:
+ # CG solver for A * u = g (parallelized over the batch dimension).
+ for i_b in range(_B):
+ func_solve_adjoint_u_cg_env(
+ i_b=i_b,
+ entities_info=entities_info,
+ rigid_global_info=rigid_global_info,
+ constraint_state=constraint_state,
+ static_rigid_sim_config=static_rigid_sim_config,
+ )
@qd.kernel
@@ -258,3 +287,437 @@ def kernel_compute_gradients(
val0 = -constraint_state.bw_u[i, i_b] * constraint_state.qacc[j, i_b]
val1 = -constraint_state.bw_u[j, i_b] * constraint_state.qacc[i, i_b]
constraint_state.dL_dM[i, j, i_b] += (val0 + val1) * 0.5 # symmetrize
+
+
+@qd.kernel(fastcache=True)
+def kernel_load_dL_dqacc_from_acc_grad(
+ dofs_state: array_class.DofsState,
+ constraint_state: array_class.ConstraintState,
+ static_rigid_sim_config: qd.template(),
+):
+ """Copy `dofs_state.acc.grad` into `constraint_state.dL_dqacc` (input buffer
+ consumed by `kernel_solve_adjoint_u`) and zero the source grad so the
+ downstream IFT-through-M path does not re-consume it.
+
+ This is a move, not an accumulation: `dL_dqacc` is overwritten for every
+ (dof, env) entry and `acc.grad` is left at zero afterwards.
+ """
+ # `acc` is indexed [dof, env]; both extents are taken from its shape.
+ _B = dofs_state.acc.shape[1]
+ n_dofs = dofs_state.acc.shape[0]
+ # Serialize the loop unless the parallelization level is PARA_LEVEL.ALL or higher.
+ qd.loop_config(
+ name="kernel_load_dL_dqacc_from_acc_grad",
+ serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL),
+ )
+ for i_d, i_b in qd.ndrange(n_dofs, _B):
+ constraint_state.dL_dqacc[i_d, i_b] = dofs_state.acc.grad[i_d, i_b]
+ dofs_state.acc.grad[i_d, i_b] = gs.qd_float(0.0)
+
+
+@qd.kernel(fastcache=True)
+def kernel_accumulate_constraint_solver_grads(
+ dofs_state: array_class.DofsState,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ constraint_state: array_class.ConstraintState,
+ static_rigid_sim_config: qd.template(),
+):
+ """Fold the constraint-solver adjoint outputs into the autograd grad fields:
+ dofs_state.force.grad += constraint_state.dL_dforce
+ rigid_global_info.mass_mat.grad += constraint_state.dL_dM
+
+ Both updates accumulate (`+=`); `dL_dforce` and `dL_dM` are only read here.
+ """
+ # `force` is indexed [dof, env]; both extents are taken from its shape.
+ _B = dofs_state.force.shape[1]
+ n_dofs = dofs_state.force.shape[0]
+ # Serialize unless the parallelization level is PARA_LEVEL.ALL or higher.
+ # NOTE(review): presumably `loop_config` only governs the loop that directly
+ # follows it, so the mass-matrix loop below would not inherit `serialize` -- confirm.
+ qd.loop_config(
+ name="kernel_accumulate_constraint_solver_grads",
+ serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL),
+ )
+ for i_d, i_b in qd.ndrange(n_dofs, _B):
+ dofs_state.force.grad[i_d, i_b] += constraint_state.dL_dforce[i_d, i_b]
+ # Full (n_dofs x n_dofs) sweep per env: every entry of dL_dM is added.
+ for i, j, i_b in qd.ndrange(n_dofs, n_dofs, _B):
+ rigid_global_info.mass_mat.grad[i, j, i_b] += constraint_state.dL_dM[i, j, i_b]
+
+
+# ---------------------------------------------------------------------------
+# Manual reverses of the constraint-force inequality constraints (collision,
+# joint-limit). Shared conventions for the two kernels below.
+#
+# Why manual (not autograd): the constraint rows are built inside the forward
+# solver with a data-dependent count and ordering -- `n_con` is assigned by
+# atomic_add as active constraints are discovered -- which autograd cannot
+# differentiate cleanly (the row index is not a static, taped quantity).
+#
+# Upstream grads: `kernel_compute_gradients` populates, per constraint row
+# `n_con`, `constraint_state.dL_daref[n_con]` (dL/d aref), `dL_defc_D[n_con]`
+# (dL/d efc_D), and `dL_djac[n_con, i_d]` (dL/d jac). The collision reverse uses
+# `dL_djac`; the joint-limit reverse ignores it (its jac entries are piecewise-
+# constant +-1, so the sub-gradient is 0). Each kernel consumes these and
+# accumulates into its own differentiable inputs.
+#
+# `n_con` row layout: the forward adds constraints in the order frictionloss ->
+# collision -> joint-limit (see `add_inequality_constraints`). So with collision
+# on, the collision group occupies rows [0, 4 * n_contacts) (4 friction-pyramid
+# rows per contact) and joint-limit rows follow it. Each reverse re-walks the
+# same forward loop deterministically to recover its own `n_con` (no atomic_add,
+# no `n_constraints` reset): collision uses `n_con = i_col * 4 + i`; joint-limit
+# seeds its counter at `4 * n_contacts`.
+#
+# TODO: only collision + joint-limit are handled. If other constraint groups
+# (equality, frictionloss) are ever added to a differentiable scene, the `n_con`
+# offset in each reverse must be updated to account for them -- they are
+# currently assumed absent (not offset).
+# ---------------------------------------------------------------------------
+@qd.kernel(fastcache=True)
+def kernel_manual_add_joint_limit_constraints_bw(
+ links_info: array_class.LinksInfo,
+ joints_info: array_class.JointsInfo,
+ dofs_info: array_class.DofsInfo,
+ dofs_state: array_class.DofsState,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ constraint_state: array_class.ConstraintState,
+ collider_state: array_class.ColliderState,
+ static_rigid_sim_config: qd.template(),
+ enable_collision: qd.template(),
+):
+ """Manual reverse of `add_joint_limit_constraints`. See the section header
+ above for the shared `n_con` layout and upstream-grad conventions.
+
+ Accumulates into rigid_global_info.qpos.grad[i_q] and dofs_state.vel.grad[i_d].
+
+ Chain rule (per active joint, `pos_delta < 0`):
+
+ Forward:
+ pos_delta_min = qpos[i_q] - limit_lo
+ pos_delta_max = limit_hi - qpos[i_q]
+ pos_delta = min(pos_delta_min, pos_delta_max)
+ sign = +1 if pos_delta_min < pos_delta_max else -1
+ jac_qvel = sign * dofs_vel[i_d]
+ imp, aref = gu.imp_aref(sol_params, pos_delta, jac_qvel, pos_delta)
+ diag_raw = invweight * (1 - imp) / imp
+ diag = max(diag_raw, EPS)
+ efc_D = 1 / diag
+
+ d(pos_delta) / d(qpos) = sign (chosen branch of `min`)
+ d(jac_qvel) / d(vel) = sign
+
+ dL/d(imp) = ga * d(aref)/d(imp) + gD * d(efc_D)/d(imp)
+ ga = dL_daref[n_con], gD = dL_defc_D[n_con]
+
+ dL/d(pos_delta) = ga * d(aref)/d(pos_delta)|_direct
+ + dL/d(imp) * d(imp)/d(imp_x) * d(imp_x)/d(pos_delta)
+
+ dL/d(jac_qvel) = ga * d(aref)/d(jac_qvel) = -ga * b_coef
+
+ dL/d(qpos) += sign * dL/d(pos_delta)
+ dL/d(vel) += sign * dL/d(jac_qvel)
+ """
+ EPS = rigid_global_info.EPS[None]
+ # Batch size comes from the last dimension of the constraint Jacobian.
+ _B = constraint_state.jac.shape[2]
+ n_links = links_info.root_idx.shape[0]
+
+ # One outer iteration per env; serialized unless the level is PARA_LEVEL.ALL or higher.
+ qd.loop_config(
+ name="kernel_manual_add_joint_limit_constraints_bw",
+ serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL),
+ )
+ for i_b in range(_B):
+ # Collision constraints (4 rows per contact) are added before joint
+ # limits in `add_inequality_constraints`, so offset the joint-limit row
+ # counter past them when collision is on.
+ n_con_counter = gs.qd_int(0)
+ if qd.static(enable_collision):
+ n_con_counter = gs.qd_int(collider_state.n_contacts[i_b] * 4)
+
+ # Walk links and their joints in index order so the row counter advances
+ # deterministically; the `*_info` fields are indexed per-env only when the
+ # matching `batch_*_info` flag is set.
+ for i_l in range(n_links):
+ I_l = [i_l, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_l
+ for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]):
+ I_j = [i_j, i_b] if qd.static(static_rigid_sim_config.batch_joints_info) else i_j
+
+ # Only REVOLUTE and PRISMATIC joints are considered for limit rows.
+ if joints_info.type[I_j] == gs.JOINT_TYPE.REVOLUTE or joints_info.type[I_j] == gs.JOINT_TYPE.PRISMATIC:
+ i_q = joints_info.q_start[I_j]
+ i_d = joints_info.dof_start[I_j]
+ I_d = [i_d, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d
+
+ pos_delta_min = rigid_global_info.qpos[i_q, i_b] - dofs_info.limit[I_d][0]
+ pos_delta_max = dofs_info.limit[I_d][1] - rigid_global_info.qpos[i_q, i_b]
+ pos_delta = qd.min(pos_delta_min, pos_delta_max)
+
+ # A negative margin means the limit is violated: this joint owns one row.
+ if pos_delta < 0:
+ n_con = n_con_counter
+ n_con_counter = n_con_counter + 1
+
+ # Replay forward intermediates (cheap, avoids stashing).
+ # sign_pos is +1 when the lower-limit margin is the smaller one, -1 otherwise.
+ sign_pos = (pos_delta_min < pos_delta_max) * 2 - 1
+ sign_f = gs.qd_float(sign_pos)
+
+ sol_params = joints_info.sol_params[I_j]
+ timeconst = sol_params[0]
+ dampratio = sol_params[1]
+ dmin = sol_params[2]
+ dmax = sol_params[3]
+ width = sol_params[4]
+ mid = sol_params[5]
+ power = sol_params[6]
+
+ # Impedance curve replay: a two-piece power curve in imp_x split at `mid`,
+ # scaled into [dmin, dmax], and pinned to dmax once imp_x exceeds 1.
+ # Presumably this matches `gu.imp_aref` term for term -- confirm.
+ imp_x = qd.abs(pos_delta) / width
+ imp_a_coef = 1.0 / mid ** (power - 1.0)
+ imp_b_coef = 1.0 / (1.0 - mid) ** (power - 1.0)
+ imp_a = imp_a_coef * imp_x**power
+ imp_b = 1.0 - imp_b_coef * (1.0 - imp_x) ** power
+ imp_y = imp_a if imp_x < mid else imp_b
+ imp_raw = dmin + imp_y * (dmax - dmin)
+ imp_clamped = qd.math.clamp(imp_raw, dmin, dmax)
+ imp = dmax if imp_x > 1.0 else imp_clamped
+
+ # Damping (b) and stiffness (k) coefficients of the reference acceleration.
+ b_coef = 2.0 / (dmax * timeconst)
+ k_coef = 1.0 / (dmax * dmax * timeconst * timeconst * dampratio * dampratio)
+
+ invweight = dofs_info.invweight[I_d]
+ diag_raw = invweight * (1.0 - imp) / imp
+ diag = qd.max(diag_raw, EPS)
+
+ # Upstream grads.
+ ga = constraint_state.dL_daref[n_con, i_b]
+ gD = constraint_state.dL_defc_D[n_con, i_b]
+
+ # --- Partials of forward outputs w.r.t. intermediates ---
+ # aref = -b_coef * jac_qvel - k_coef * imp * pos_delta
+ d_aref_d_imp = -k_coef * pos_delta
+ d_aref_d_jac_qvel = -b_coef
+ d_aref_d_pos_delta_direct = -k_coef * imp
+
+ # diag_raw = invweight*(1-imp)/imp ⇒ d(diag_raw)/d(imp) = -invweight/imp^2
+ # diag = max(diag_raw, EPS); efc_D = 1/diag
+ # d(efc_D)/d(imp) = -1/diag^2 · d(diag)/d(imp), 0 if clamped to EPS
+ d_diag_d_imp = gs.qd_float(0.0)
+ if diag_raw > EPS:
+ d_diag_d_imp = -invweight / (imp * imp)
+ d_efc_D_d_imp = -d_diag_d_imp / (diag * diag)
+
+ # d(imp)/d(imp_x): active only inside the smooth clamp band.
+ within_clamp = (imp_raw > dmin) and (imp_raw < dmax) and (imp_x <= 1.0)
+ d_imp_y_d_imp_x = gs.qd_float(0.0)
+ if imp_x < mid:
+ d_imp_y_d_imp_x = power * imp_a_coef * imp_x ** (power - 1.0)
+ else:
+ d_imp_y_d_imp_x = power * imp_b_coef * (1.0 - imp_x) ** (power - 1.0)
+ d_imp_d_imp_x = gs.qd_float(0.0)
+ if within_clamp:
+ d_imp_d_imp_x = (dmax - dmin) * d_imp_y_d_imp_x
+
+ # d(imp_x)/d(pos_delta) = sign(pos_delta)/width; pos_delta < 0 ⇒ -1/width
+ d_imp_x_d_pos_delta = -1.0 / width
+ d_imp_d_pos_delta = d_imp_d_imp_x * d_imp_x_d_pos_delta
+
+ # --- Combine ---
+ dL_d_imp = ga * d_aref_d_imp + gD * d_efc_D_d_imp
+ dL_d_pos_delta = ga * d_aref_d_pos_delta_direct + dL_d_imp * d_imp_d_pos_delta
+ dL_d_jac_qvel = ga * d_aref_d_jac_qvel
+
+ # --- Propagate ---
+ # Both writes accumulate into the existing grads rather than overwrite them.
+ rigid_global_info.qpos.grad[i_q, i_b] += sign_f * dL_d_pos_delta
+ dofs_state.vel.grad[i_d, i_b] += sign_f * dL_d_jac_qvel
+
+
+@qd.kernel(fastcache=True)
+def kernel_manual_add_collision_constraints_bw(
+ links_info: array_class.LinksInfo,
+ links_state: array_class.LinksState,
+ dofs_state: array_class.DofsState,
+ constraint_state: array_class.ConstraintState,
+ collider_state: array_class.ColliderState,
+ rigid_global_info: array_class.RigidGlobalInfo,
+ static_rigid_sim_config: qd.template(),
+):
+ """Manual reverse of `add_collision_constraints`. See the section header
+ above for the shared `n_con` layout and upstream-grad conventions.
+
+ Produces the gradients w.r.t. the collision constraint's differentiable inputs:
+ collider_state.contact_data.{pos, normal, penetration}.grad (-> collider.backward)
+ dofs_state.{cdof_ang, cdof_vel, vel}.grad
+ links_state.root_COM.grad
+ (cdof / root_COM / vel grads feed the COM / forward-dynamics reverse chain;
+ contact_data grads feed `collider.backward`.)
+
+ Forward recap (per contact `i_col`, per friction-pyramid row `i` in 0..3):
+ d1, d2 = qd_orthogonals(normal); d = s_i * (d1 if i<2 else d2), s_i = 2*(i%2)-1
+ n = d * friction - normal
+ jac[n_con, i_d] = sum_chain (sign * vel_motion(i_d)) . n
+ vel_motion = cdof_vel - t_pos x cdof_ang, t_pos = contact_pos - root_COM[link]
+ jac_qvel = sum_chain jac[n_con, i_d] * dofs_vel[i_d]
+ imp, aref = imp_aref(sol_params, -penetration, jac_qvel, -penetration)
+ diag = (invweight + friction^2 invweight) * 2 friction^2 (1-imp)/imp ; efc_D = 1/diag
+
+ The dofs/links grads are accumulated with `+=`; the three `contact_data`
+ grads are written with plain assignment at the end of each contact.
+ """
+ EPS = rigid_global_info.EPS[None]
+ _B = dofs_state.ctrl_mode.shape[1]
+ # NOTE(review): `n_dofs` is not referenced again in this kernel.
+ n_dofs = dofs_state.ctrl_mode.shape[0]
+ max_contact_pairs = collider_state.contact_data.link_a.shape[0]
+
+ qd.loop_config(
+ name="kernel_manual_add_collision_constraints_bw",
+ serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL),
+ )
+ # Single flat loop over (contact slot, env); the env index varies fastest.
+ for flat_idx in range(max_contact_pairs * _B):
+ i_b = flat_idx % _B
+ i_col = flat_idx // _B
+ if i_col < collider_state.n_contacts[i_b]:
+ link_a = collider_state.contact_data.link_a[i_col, i_b]
+ link_b = collider_state.contact_data.link_b[i_col, i_b]
+ contact_pos = collider_state.contact_data.pos[i_col, i_b]
+ normal = collider_state.contact_data.normal[i_col, i_b]
+ friction = collider_state.contact_data.friction[i_col, i_b]
+ sol_params = collider_state.contact_data.sol_params[i_col, i_b]
+ penetration = collider_state.contact_data.penetration[i_col, i_b]
+
+ # Combined inverse weight: link_a's first invweight entry, plus link_b's when link_b is a valid index.
+ link_a_maybe_batch = [link_a, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else link_a
+ invweight = links_info.invweight[link_a_maybe_batch][0]
+ if link_b > -1:
+ link_b_maybe_batch = [link_b, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else link_b
+ invweight = invweight + links_info.invweight[link_b_maybe_batch][0]
+
+ # --- forward intermediates of qd_orthogonals(normal) ---
+ # b_raw branches on |normal[1]| < 0.5; b = normalize(b_raw)
+ # d1 = b x normal, d2 = b
+ n0, n1, n2 = normal[0], normal[1], normal[2]
+ branch_a = qd.abs(n1) < 0.5
+ b_raw = gs.qd_vec3(0.0, 0.0, 0.0)
+ if branch_a:
+ b_raw = gs.qd_vec3(-n0 * n1, 1.0 - n1 * n1, -n2 * n1)
+ else:
+ b_raw = gs.qd_vec3(-n0 * n2, -n1 * n2, 1.0 - n2 * n2)
+ b_raw_norm = b_raw.norm()
+ b = b_raw / b_raw_norm
+ d1 = b.cross(normal)
+ d2 = b
+
+ sol_timeconst = sol_params[0]
+ sol_dampratio = sol_params[1]
+ sol_dmin = sol_params[2]
+ sol_dmax = sol_params[3]
+ sol_width = sol_params[4]
+ sol_mid = sol_params[5]
+ sol_power = sol_params[6]
+
+ # The forward passes -penetration as the position argument, so replay with neg_pen.
+ neg_pen = -penetration
+ imp_x = qd.abs(neg_pen) / sol_width
+ # d(imp_x)/d(penetration) = -sign(neg_pen)/width
+ sign_neg = gs.qd_float(1.0) if neg_pen >= 0 else gs.qd_float(-1.0)
+ d_imp_x_d_pen = -sign_neg / sol_width
+
+ # Impedance curve replay: two-piece power curve split at `sol_mid`, scaled
+ # into [sol_dmin, sol_dmax], and pinned to sol_dmax once imp_x exceeds 1.
+ imp_a_coef = 1.0 / sol_mid ** (sol_power - 1.0)
+ imp_b_coef = 1.0 / (1.0 - sol_mid) ** (sol_power - 1.0)
+ imp_a = imp_a_coef * imp_x**sol_power
+ imp_b = 1.0 - imp_b_coef * (1.0 - imp_x) ** sol_power
+ imp_y = imp_a if imp_x < sol_mid else imp_b
+ imp_raw = sol_dmin + imp_y * (sol_dmax - sol_dmin)
+ imp_clamped = qd.math.clamp(imp_raw, sol_dmin, sol_dmax)
+ imp = sol_dmax if imp_x > 1.0 else imp_clamped
+
+ b_coef = 2.0 / (sol_dmax * sol_timeconst)
+ # k_coef matches gu.imp_aref's k = 1/(dmax^2 timeconst^2 dampratio^2)
+ k_coef = 1.0 / (sol_dmax * sol_dmax * sol_timeconst * sol_timeconst * sol_dampratio * sol_dampratio)
+
+ # diag = C0 * (1-imp)/imp, C0 = 2 friction^2 invweight (1 + friction^2)
+ C0 = (invweight + friction * friction * invweight) * 2.0 * friction * friction
+ diag_raw = C0 * (1.0 - imp) / imp
+ diag = qd.max(diag_raw, EPS)
+
+ # d(imp)/d(imp_x) is non-zero only when neither the clamp nor the imp_x > 1 override is active.
+ within_clamp = (imp_raw > sol_dmin) and (imp_raw < sol_dmax) and (imp_x <= 1.0)
+ d_imp_y_d_imp_x = gs.qd_float(0.0)
+ if imp_x < sol_mid:
+ d_imp_y_d_imp_x = sol_power * imp_a_coef * imp_x ** (sol_power - 1.0)
+ else:
+ d_imp_y_d_imp_x = sol_power * imp_b_coef * (1.0 - imp_x) ** (sol_power - 1.0)
+ d_imp_d_imp_x = gs.qd_float(0.0)
+ if within_clamp:
+ d_imp_d_imp_x = (sol_dmax - sol_dmin) * d_imp_y_d_imp_x
+
+ # d(efc_D)/d(imp) through diag = max(diag_raw, EPS); zero when the EPS floor is active.
+ d_diag_d_imp = gs.qd_float(0.0)
+ if diag_raw > EPS:
+ d_diag_d_imp = -C0 / (imp * imp)
+ d_efc_D_d_imp = -d_diag_d_imp / (diag * diag)
+
+ # Accumulators for this contact's differentiable inputs.
+ g_pos = gs.qd_vec3(0.0, 0.0, 0.0)
+ g_normal = gs.qd_vec3(0.0, 0.0, 0.0)
+ g_pen = gs.qd_float(0.0)
+ g_d1 = gs.qd_vec3(0.0, 0.0, 0.0)
+ g_d2 = gs.qd_vec3(0.0, 0.0, 0.0)
+
+ # Four friction-pyramid rows per contact: -d1, +d1, -d2, +d2 (s_i alternates -1, +1).
+ for i in range(4):
+ s_i = gs.qd_float(2 * (i % 2) - 1)
+ d = s_i * d1 if i < 2 else s_i * d2
+ n = d * friction - normal
+ n_con = i_col * 4 + i
+
+ ga = constraint_state.dL_daref[n_con, i_b]
+ gD = constraint_state.dL_defc_D[n_con, i_b]
+
+ # aref = -b_coef*jac_qvel + k_coef*imp*penetration (pos arg = -penetration)
+ d_aref_d_imp = k_coef * penetration
+ d_aref_d_pen_direct = k_coef * imp
+ d_aref_d_jac_qvel = -b_coef
+
+ dL_d_imp = ga * d_aref_d_imp + gD * d_efc_D_d_imp
+ dL_d_pen = ga * d_aref_d_pen_direct + dL_d_imp * d_imp_d_imp_x * d_imp_x_d_pen
+ g_pen += dL_d_pen
+ dL_d_jac_qvel = ga * d_aref_d_jac_qvel
+
+ # Reverse jac[n_con, i_d] over the kinematic chain.
+ # link_a's chain is walked with sign -1, link_b's with sign +1; each walk
+ # follows parent_idx until it reaches -1 (so an invalid link_b is skipped).
+ dL_dn = gs.qd_vec3(0.0, 0.0, 0.0)
+ for i_ab in range(2):
+ sign = gs.qd_float(-1.0)
+ link = link_a
+ if i_ab == 1:
+ sign = gs.qd_float(1.0)
+ link = link_b
+ while link > -1:
+ link_mb = [link, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else link
+ # Visit this link's dofs from the last one down to the first.
+ for i_d_ in range(links_info.n_dofs[link_mb]):
+ i_d = links_info.dof_end[link_mb] - 1 - i_d_
+
+ cdof_ang = dofs_state.cdof_ang[i_d, i_b]
+ cdof_vel = dofs_state.cdof_vel[i_d, i_b]
+ t_pos = contact_pos - links_state.root_COM[link, i_b]
+ vel_motion = cdof_vel - t_pos.cross(cdof_ang)
+
+ # The upstream grad on this Jacobian entry combines the direct dL/djac
+ # with the path through jac_qvel = sum jac * vel.
+ jac_stored = constraint_state.jac[n_con, i_d, i_b]
+ g_jac = constraint_state.dL_djac[n_con, i_d, i_b] + dL_d_jac_qvel * dofs_state.vel[i_d, i_b]
+ dofs_state.vel.grad[i_d, i_b] += dL_d_jac_qvel * jac_stored
+
+ # jac_contrib = (sign * vel_motion) . n
+ dL_dn += g_jac * sign * vel_motion
+ g_vm = g_jac * sign * n # dL/d(vel_motion)
+
+ # vel_motion = cdof_vel - t_pos x cdof_ang
+ dofs_state.cdof_vel.grad[i_d, i_b] += g_vm
+ dofs_state.cdof_ang.grad[i_d, i_b] += t_pos.cross(g_vm)
+ dt = -(cdof_ang.cross(g_vm)) # dL/d(t_pos)
+ # t_pos = contact_pos - root_COM, so the two receive opposite-signed grads.
+ g_pos += dt
+ links_state.root_COM.grad[link, i_b] += -dt
+
+ link = links_info.parent_idx[link_mb]
+
+ # n = d*friction - normal
+ g_normal += -dL_dn
+ g_d = dL_dn * friction
+ if i < 2:
+ g_d1 += s_i * g_d
+ else:
+ g_d2 += s_i * g_d
+
+ # Reverse qd_orthogonals: d1 = b x normal, d2 = b, b = normalize(b_raw(normal)).
+ dL_db = g_d2 + normal.cross(g_d1)
+ g_normal += g_d1.cross(b)
+ # b = b_raw / |b_raw|
+ # Normalization adjoint: remove the component along b, then divide by |b_raw|.
+ dL_db_raw = (dL_db - dL_db.dot(b) * b) / b_raw_norm
+ # b_raw(normal) branch Jacobian
+ if branch_a:
+ # b_raw = (-n0 n1, 1 - n1^2, -n2 n1)
+ g_normal[0] += dL_db_raw[0] * (-n1)
+ g_normal[1] += dL_db_raw[0] * (-n0) + dL_db_raw[1] * (-2.0 * n1) + dL_db_raw[2] * (-n2)
+ g_normal[2] += dL_db_raw[2] * (-n1)
+ else:
+ # b_raw = (-n0 n2, -n1 n2, 1 - n2^2)
+ g_normal[0] += dL_db_raw[0] * (-n2)
+ g_normal[1] += dL_db_raw[1] * (-n2)
+ g_normal[2] += dL_db_raw[0] * (-n0) + dL_db_raw[1] * (-n1) + dL_db_raw[2] * (-2.0 * n2)
+
+ # Store this contact's accumulated grads. These are plain assignments, so any
+ # value previously held in these contact_data grad slots is replaced.
+ for j in qd.static(range(3)):
+ collider_state.contact_data.pos.grad[i_col, i_b][j] = g_pos[j]
+ collider_state.contact_data.normal.grad[i_col, i_b][j] = g_normal[j]
+ collider_state.contact_data.penetration.grad[i_col, i_b] = g_pen
diff --git a/genesis/engine/solvers/rigid/constraint/noslip.py b/genesis/engine/solvers/rigid/constraint/noslip.py
index cf4c465a93..7df5fa94b0 100644
--- a/genesis/engine/solvers/rigid/constraint/noslip.py
+++ b/genesis/engine/solvers/rigid/constraint/noslip.py
@@ -44,11 +44,9 @@ def kernel_build_efc_AR_b(
i_b,
constraint_state.Mgrad,
constraint_state.Mgrad,
- None,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
- is_backward=False,
)
# TODO: For consistency with other usages, migrate to either the lower or upper variant
@@ -331,11 +329,9 @@ def kernel_dual_finish(
i_b=i_b,
vec=constraint_state.qfrc_constraint,
out=constraint_state.qacc,
- out_bw=None,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
- is_backward=False,
)
for i_d in range(n_dofs):
@@ -426,11 +422,9 @@ def compute_A_diag(
i_b,
constraint_state.Mgrad,
constraint_state.Mgrad,
- None,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
- is_backward=False,
)
# Ai = Ji * tmp
diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py
index 0283f7c206..c1397b7653 100644
--- a/genesis/engine/solvers/rigid/constraint/solver.py
+++ b/genesis/engine/solvers/rigid/constraint/solver.py
@@ -442,12 +442,13 @@ def delete_weld_constraint(self, link1_idx, link2_idx, envs_idx=None):
self._solver._static_rigid_sim_config,
)
- def backward(self, dL_dqacc):
+ def backward(self):
if not self._solver._requires_grad:
gs.raise_exception("Please set `requires_grad` to True in SimOptions to enable differentiable mode.")
- # Copy upstream gradients
- self.constraint_state.dL_dqacc.from_numpy(dL_dqacc)
+ # Upstream gradient `dL_dqacc` is expected to be pre-populated in
+ # `constraint_state.dL_dqacc` by the caller (see
+ # `kernel_load_dL_dqacc_from_acc_grad`).
# 1. We first need to find a solution to A^T * u = g system.
backward_constraint_solver.kernel_solve_adjoint_u(
@@ -3522,11 +3523,9 @@ def func_update_gradient_batch(
i_b,
constraint_state.grad,
constraint_state.Mgrad,
- None,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
- is_backward=False,
)
if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton):
@@ -3569,11 +3568,9 @@ def func_update_gradient_tiled(
i_b,
constraint_state.grad,
constraint_state.Mgrad,
- None,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
- is_backward=False,
)
if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton):
diff --git a/genesis/engine/solvers/rigid/constraint/solver_island.py b/genesis/engine/solvers/rigid/constraint/solver_island.py
index 3a300cda20..b18a4ff833 100644
--- a/genesis/engine/solvers/rigid/constraint/solver_island.py
+++ b/genesis/engine/solvers/rigid/constraint/solver_island.py
@@ -987,11 +987,9 @@ def _func_update_gradient(
i_b,
self.grad,
self.Mgrad,
- None,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=self._solver._static_rigid_sim_config,
- is_backward=False,
)
for i_e in range(self._solver.n_entities):
self._solver.mass_mat_mask[i_e, i_b] = True
diff --git a/genesis/engine/solvers/rigid/rigid_solver.py b/genesis/engine/solvers/rigid/rigid_solver.py
index cf61c94ef2..85f700de74 100644
--- a/genesis/engine/solvers/rigid/rigid_solver.py
+++ b/genesis/engine/solvers/rigid/rigid_solver.py
@@ -29,6 +29,12 @@
from ..kinematic_solver import KinematicSolver
from .collider import Collider
from .constraint import ConstraintSolver, ConstraintSolverIsland
+from .constraint.backward import (
+ kernel_manual_add_joint_limit_constraints_bw,
+ kernel_manual_add_collision_constraints_bw,
+ kernel_load_dL_dqacc_from_acc_grad,
+ kernel_accumulate_constraint_solver_grads,
+)
from .abd.misc import (
func_add_safe_backward,
func_apply_coupling_force,
@@ -91,7 +97,10 @@
kernel_update_all_verts,
kernel_update_geom_aabbs,
kernel_update_vgeoms,
+ kernel_COM_links_replay,
kernel_update_cartesian_space,
+ kernel_forward_kinematics_replay,
+ kernel_update_geoms_replay,
)
from .abd.forward_dynamics import (
func_actuation,
@@ -169,6 +178,12 @@
kernel_prepare_backward_substep,
kernel_begin_backward_substep,
kernel_copy_acc,
+ kernel_copy_next_to_curr_no_check,
+)
+from .abd.manual_bw import (
+ kernel_manual_compute_qacc_bw,
+ kernel_manual_forward_kinematics_bw,
+ kernel_manual_forward_velocity_bw,
)
if TYPE_CHECKING:
@@ -1027,7 +1042,6 @@ def substep(self, f):
self.dofs_state,
self._rigid_global_info,
self._static_rigid_sim_config,
- self._is_backward,
)
else:
self._func_constraint_force()
@@ -1089,6 +1103,12 @@ def check_errno(self):
gs.raise_exception("Invalid accelerations causing 'nan'. Please decrease Rigid simulation timestep.")
if errno & array_class.ErrorCode.OVERFLOW_HIBERNATION_ISLANDS:
gs.raise_exception("Contact island buffer overflow. Please increase RigidOptions 'max_collision_pairs'.")
+ if errno & array_class.ErrorCode.MANUAL_BW_UNIMPLEMENTED:
+ gs.raise_exception(
+ "Encountered a configuration (e.g. hibernation) that the manual backward kernels "
+ "do not support. Extend the corresponding `kernel_manual_*_bw` in "
+ "`genesis/engine/solvers/rigid/abd/manual_bw.py`."
+ )
def _kernel_detect_collision(self):
self.collider.clear()
@@ -1302,6 +1322,105 @@ def reset_grad(self):
qd_zero_grad(self.geoms_state_adjoint_cache)
qd_zero_grad(self._rigid_adjoint_cache)
+ def _update_cartesian_grad(self, envs_idx):
+ """Forward-replay the post-integrate cartesian-space update
+ (FK -> COM -> geom poses -> velocity) under `is_backward=True`, then
+ reverse it: velocity and FK are reversed manually (`kernel_manual_*_bw`),
+ while COM and the link->geom transform are reversed by Quadrants autograd
+ (`.grad`). Shared by both the post-FK reverse and the first-substep
+ initial-state reverse in `substep_pre_coupling_grad`.
+ """
+ # Forward replay in dependency order (FK -> COM -> geoms -> velocity).
+ kernel_forward_kinematics_replay(
+ envs_idx=envs_idx,
+ links_state=self.links_state,
+ links_info=self.links_info,
+ joints_state=self.joints_state,
+ joints_info=self.joints_info,
+ dofs_state=self.dofs_state,
+ dofs_info=self.dofs_info,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ is_backward=True,
+ )
+ kernel_COM_links_replay(
+ links_state=self.links_state,
+ links_info=self.links_info,
+ joints_state=self.joints_state,
+ joints_info=self.joints_info,
+ dofs_state=self.dofs_state,
+ dofs_info=self.dofs_info,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ is_backward=True,
+ )
+ kernel_update_geoms_replay(
+ entities_info=self.entities_info,
+ geoms_state=self.geoms_state,
+ geoms_info=self.geoms_info,
+ links_state=self.links_state,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ is_backward=True,
+ )
+ kernel_forward_velocity(
+ envs_idx=envs_idx,
+ links_state=self.links_state,
+ links_info=self.links_info,
+ joints_info=self.joints_info,
+ dofs_state=self.dofs_state,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ is_backward=True,
+ )
+
+ # Reverse pass: vel.bw -> COM.bw -> geoms.bw -> FK.bw (COM/geoms not swapped; TODO confirm independent).
+ kernel_manual_forward_velocity_bw(
+ links_state=self.links_state,
+ links_info=self.links_info,
+ joints_info=self.joints_info,
+ dofs_state=self.dofs_state,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ errno=self._errno,
+ )
+ kernel_COM_links_replay.grad(
+ links_state=self.links_state,
+ links_info=self.links_info,
+ joints_state=self.joints_state,
+ joints_info=self.joints_info,
+ dofs_state=self.dofs_state,
+ dofs_info=self.dofs_info,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ is_backward=True,
+ )
+ kernel_update_geoms_replay.grad(
+ entities_info=self.entities_info,
+ geoms_state=self.geoms_state,
+ geoms_info=self.geoms_info,
+ links_state=self.links_state,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ is_backward=True,
+ )
+ kernel_manual_forward_kinematics_bw(
+ links_state=self.links_state,
+ links_info=self.links_info,
+ joints_state=self.joints_state,
+ joints_info=self.joints_info,
+ dofs_info=self.dofs_info,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ errno=self._errno,
+ )
+
def substep_pre_coupling_grad(self, f):
# Change to backward mode
self._is_backward = True
@@ -1328,36 +1447,17 @@ def substep_pre_coupling_grad(self, f):
static_rigid_sim_config=self._static_rigid_sim_config,
)
self.substep(f)
-
# =================== Backward substep ======================
envs_idx = self._scene._sanitize_envs_idx(None)
if not self._enable_mujoco_compatibility:
- kernel_forward_velocity.grad(
- envs_idx=envs_idx,
- links_state=self.links_state,
- links_info=self.links_info,
- joints_info=self.joints_info,
+ # The FK backward below builds its Jacobian at the post-integrate qpos/vel, so
+ # copy the integrator's `_next` outputs into the current slots first.
+ kernel_copy_next_to_curr_no_check(
dofs_state=self.dofs_state,
- entities_info=self.entities_info,
rigid_global_info=self._rigid_global_info,
static_rigid_sim_config=self._static_rigid_sim_config,
- is_backward=True,
- )
- kernel_update_cartesian_space.grad(
- links_state=self.links_state,
- links_info=self.links_info,
- joints_state=self.joints_state,
- joints_info=self.joints_info,
- dofs_state=self.dofs_state,
- dofs_info=self.dofs_info,
- geoms_state=self.geoms_state,
- geoms_info=self.geoms_info,
- entities_info=self.entities_info,
- rigid_global_info=self._rigid_global_info,
- static_rigid_sim_config=self._static_rigid_sim_config,
- force_update_fixed_geoms=False,
- is_backward=True,
)
+ self._update_cartesian_grad(envs_idx)
is_grad_valid = kernel_begin_backward_substep(
f=f,
@@ -1382,40 +1482,104 @@ def substep_pre_coupling_grad(self, f):
gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}")
kernel_step_2.grad(
- dofs_state=self.dofs_state,
- dofs_info=self.dofs_info,
- links_info=self.links_info,
- links_state=self.links_state,
- joints_info=self.joints_info,
- joints_state=self.joints_state,
- entities_state=self.entities_state,
- entities_info=self.entities_info,
- geoms_info=self.geoms_info,
- geoms_state=self.geoms_state,
- collider_state=self.collider._collider_state,
- rigid_global_info=self._rigid_global_info,
- static_rigid_sim_config=self._static_rigid_sim_config,
- contact_island_state=self.constraint_solver.contact_island.contact_island_state,
- is_backward=True,
- errno=self._errno,
+ self.dofs_state,
+ self.dofs_info,
+ self.links_info,
+ self.links_state,
+ self.joints_info,
+ self.joints_state,
+ self.entities_state,
+ self.entities_info,
+ self.geoms_info,
+ self.geoms_state,
+ self.collider._collider_state,
+ self._rigid_global_info,
+ self._static_rigid_sim_config,
+ self.constraint_solver.contact_island.contact_island_state,
+ self._is_backward,
+ self._errno,
)
- # We cannot use [kernel_forward_dynamics.grad] because we read [dofs_state.acc] and overwrite it in the kernel,
- # which is prohibited (https://docs.taichi-lang.org/docs/differentiable_programming#global-data-access-rules).
- # In [kernel_forward_dynamics], we read [acc] in [func_update_acc] and overwrite it in [kernel_compute_qacc].
- # As [kenrel_compute_qacc] is called at the end of [kernel_forward_dynamics], we first backpropagate through
- # [kernel_compute_qacc] and then restore the original [acc] from the adjoint cache. This copy operation
- # cannot be merged with [kernel_compute_qacc.grad] because .grad function itself is a standalone kernel.
- # We could possibly merge this small kernel later if (1) .grad function is regarded as a function instead of a
- # kernel, (2) we add another variable to store the new [acc] from [kernel_compute_qacc] and thus can avoid
- # the data access violation. However, both of these require major changes.
- kernel_compute_qacc.grad(
- dofs_state=self.dofs_state,
- entities_info=self.entities_info,
- rigid_global_info=self._rigid_global_info,
- static_rigid_sim_config=self._static_rigid_sim_config,
- is_backward=True,
- )
+ # Two backward paths for `force -> acc`:
+ # (A) Unconstrained: `kernel_manual_compute_qacc_bw` does the IFT through M
+ # (writes force.grad + mass_mat.grad).
+ # (B) Constrained (joint limits / collision enabled, even if no row is active): the constraint solve
+ # overwrites acc, so the IFT through M alone is wrong. Instead drive
+ # `constraint_solver.backward` (adjoint KKT -> dL_dM/djac/daref/defc_D/
+ # dforce) + the per-constraint manual reverses below.
+ _has_collision = (not self._disable_constraint) and self._enable_collision
+ _has_joint_limit = (not self._disable_constraint) and self._options.enable_joint_limit
+ _constrained_bw = _has_collision or _has_joint_limit
+ if _constrained_bw:
+ # Reject equality / frictionloss constraints: forward orders rows as
+ # equality -> frictionloss -> collision -> joint-limit, but the
+ # manual reverses below only re-walk the last two groups.
+ # TODO: implement manual reverses for equality and frictionloss
+ # rows so they can participate in differentiable scenes; until then
+ # we reject host-side instead of producing a wrong gradient.
+ cs = self.constraint_solver.constraint_state
+ n_eq_max = int(qd_to_numpy(cs.n_constraints_equality).max())
+ n_fric_max = int(qd_to_numpy(cs.n_constraints_frictionloss).max())
+ if n_eq_max > 0 or n_fric_max > 0:
+ gs.raise_exception(
+ "Differentiable rigid backward does not support equality or frictionloss "
+ f"constraints (found n_constraints_equality={n_eq_max}, "
+ f"n_constraints_frictionloss={n_fric_max}). Disable them in a differentiable scene."
+ )
+
+ # Backward pass for the constraint solver
+ kernel_load_dL_dqacc_from_acc_grad(
+ dofs_state=self.dofs_state,
+ constraint_state=self.constraint_solver.constraint_state,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ )
+ self.constraint_solver.backward()
+ kernel_accumulate_constraint_solver_grads(
+ dofs_state=self.dofs_state,
+ rigid_global_info=self._rigid_global_info,
+ constraint_state=self.constraint_solver.constraint_state,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ )
+
+ if _has_collision:
+ # Manual reverse of the collision constraint rows -> contact-data
+ # grads, then differentiate the narrow-phase (diff GJK) through
+ # `collider.backward_narrowphase` into geom pose grads.
+ collider_state = self.collider._collider_state
+ collider_state.contact_data.pos.grad.fill(0.0)
+ collider_state.contact_data.normal.grad.fill(0.0)
+ collider_state.contact_data.penetration.grad.fill(0.0)
+ kernel_manual_add_collision_constraints_bw(
+ links_info=self.links_info,
+ links_state=self.links_state,
+ dofs_state=self.dofs_state,
+ constraint_state=self.constraint_solver.constraint_state,
+ collider_state=collider_state,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ )
+ self.collider.backward_narrowphase()
+
+ if _has_joint_limit:
+ kernel_manual_add_joint_limit_constraints_bw(
+ links_info=self.links_info,
+ joints_info=self.joints_info,
+ dofs_info=self.dofs_info,
+ dofs_state=self.dofs_state,
+ rigid_global_info=self._rigid_global_info,
+ constraint_state=self.constraint_solver.constraint_state,
+ collider_state=self.collider._collider_state,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ enable_collision=_has_collision,
+ )
+ else:
+ # Manual backward for `func_compute_qacc` via the Implicit Function Theorem.
+ kernel_manual_compute_qacc_bw(
+ dofs_state=self.dofs_state,
+ entities_info=self.entities_info,
+ rigid_global_info=self._rigid_global_info,
+ static_rigid_sim_config=self._static_rigid_sim_config,
+ )
kernel_copy_acc(
f=f,
dofs_state=self.dofs_state,
@@ -1440,32 +1604,7 @@ def substep_pre_coupling_grad(self, f):
# If it was the very first substep, we need to backpropagate through the initial update of the cartesian space
if self._enable_mujoco_compatibility or self._sim.cur_substep_global == 0:
- kernel_forward_velocity.grad(
- envs_idx=envs_idx,
- links_state=self.links_state,
- links_info=self.links_info,
- joints_info=self.joints_info,
- dofs_state=self.dofs_state,
- entities_info=self.entities_info,
- rigid_global_info=self._rigid_global_info,
- static_rigid_sim_config=self._static_rigid_sim_config,
- is_backward=True,
- )
- kernel_update_cartesian_space.grad(
- links_state=self.links_state,
- links_info=self.links_info,
- joints_state=self.joints_state,
- joints_info=self.joints_info,
- dofs_state=self.dofs_state,
- dofs_info=self.dofs_info,
- geoms_state=self.geoms_state,
- geoms_info=self.geoms_info,
- entities_info=self.entities_info,
- rigid_global_info=self._rigid_global_info,
- static_rigid_sim_config=self._static_rigid_sim_config,
- force_update_fixed_geoms=False,
- is_backward=True,
- )
+ self._update_cartesian_grad(envs_idx)
# Change back to forward mode
self._is_backward = False
@@ -1481,7 +1620,6 @@ def substep_post_coupling(self, f):
dofs_state=self.dofs_state,
rigid_global_info=self._rigid_global_info,
static_rigid_sim_config=self._static_rigid_sim_config,
- is_backward=self._is_backward,
)
kernel_step_2(
dofs_state=self.dofs_state,
@@ -1688,9 +1826,11 @@ def save_ckpt(self, ckpt_name):
if ckpt_name not in self._ckpt:
self._ckpt[ckpt_name] = dict()
- self._ckpt[ckpt_name]["qpos"] = qd_to_numpy(self._rigid_adjoint_cache.qpos)
- self._ckpt[ckpt_name]["dofs_vel"] = qd_to_numpy(self._rigid_adjoint_cache.dofs_vel)
- self._ckpt[ckpt_name]["dofs_acc"] = qd_to_numpy(self._rigid_adjoint_cache.dofs_acc)
+ # `copy=True` required: with the zerocopy backend `qd_to_numpy` returns a
+ # view, so later substeps would overwrite this ckpt's buffer in place.
+ self._ckpt[ckpt_name]["qpos"] = qd_to_numpy(self._rigid_adjoint_cache.qpos, copy=True)
+ self._ckpt[ckpt_name]["dofs_vel"] = qd_to_numpy(self._rigid_adjoint_cache.dofs_vel, copy=True)
+ self._ckpt[ckpt_name]["dofs_acc"] = qd_to_numpy(self._rigid_adjoint_cache.dofs_acc, copy=True)
for entity in self._entities:
entity.save_ckpt(ckpt_name)
diff --git a/genesis/utils/array_class.py b/genesis/utils/array_class.py
index 5baad880da..564f073ec5 100644
--- a/genesis/utils/array_class.py
+++ b/genesis/utils/array_class.py
@@ -86,6 +86,7 @@ class ErrorCode(IntEnum):
OVERFLOW_HIBERNATION_ISLANDS = 0b00000000000000000000000000000100
INVALID_FORCE_NAN = 0b00000000000000000000000000001000
INVALID_ACC_NAN = 0b00000000000000000000000000010000
+ MANUAL_BW_UNIMPLEMENTED = 0b00000000000000000000000000100000
# =========================================== RigidGlobalInfo ===========================================
@@ -108,7 +109,6 @@ class RigidGlobalInfo:
geoms_init_AABB: qd.Tensor
mass_mat: qd.Tensor
mass_mat_L: qd.Tensor
- mass_mat_L_bw: qd.Tensor
mass_mat_D_inv: qd.Tensor
mass_mat_mask: qd.Tensor
meaninertia: qd.Tensor
@@ -138,11 +138,6 @@ def get_rigid_global_info(solver, kinematic_only):
f"Mass matrix shape (n_dofs={solver.n_dofs_}, n_dofs={solver.n_dofs_}, n_envs={_B}) is too large."
)
requires_grad = solver._requires_grad
- mass_mat_shape_bw = maybe_shape((2, *mass_mat_shape), requires_grad)
- if math.prod(mass_mat_shape_bw) > np.iinfo(np.int32).max:
- gs.raise_exception(
- f"Mass matrix buffer shape (2, n_dofs={solver.n_dofs_}, n_dofs={solver.n_dofs_}, n_envs={_B}) is too large."
- )
# Flip mass_mat from canonical (n_dofs(i_d1), n_dofs(i_d2), _B) -> physical (_B, n_dofs(i_d2), n_dofs(i_d1)) via
# layout=(2, 1, 0): i_d1 becomes innermost / stride-1, which coalesces consumer kernels whose lanes stride i_d1
@@ -175,7 +170,6 @@ def get_rigid_global_info(solver, kinematic_only):
geoms_init_AABB=V_VEC(3, dtype=gs.qd_float, shape=()),
mass_mat=V(dtype=gs.qd_float, shape=()),
mass_mat_L=V(dtype=gs.qd_float, shape=()),
- mass_mat_L_bw=V(dtype=gs.qd_float, shape=()),
mass_mat_D_inv=V(dtype=gs.qd_float, shape=()),
mass_mat_mask=V(dtype=gs.qd_bool, shape=()),
mass_parent_mask=V(dtype=gs.qd_float, shape=()),
@@ -210,7 +204,6 @@ def get_rigid_global_info(solver, kinematic_only):
geoms_init_AABB=V_VEC(3, dtype=gs.qd_float, shape=(solver.n_geoms_, 8)),
mass_mat=V(dtype=gs.qd_float, shape=mass_mat_shape, layout=mass_mat_layout, needs_grad=requires_grad),
mass_mat_L=V(dtype=gs.qd_float, shape=mass_mat_shape, needs_grad=requires_grad),
- mass_mat_L_bw=V(dtype=gs.qd_float, shape=mass_mat_shape_bw, needs_grad=requires_grad),
mass_mat_D_inv=V(dtype=gs.qd_float, shape=(solver.n_dofs_, _B), needs_grad=requires_grad),
mass_mat_mask=V(dtype=gs.qd_bool, shape=(solver.n_entities_, _B)),
mass_parent_mask=V(dtype=gs.qd_float, shape=(solver.n_dofs_, solver.n_dofs_)),
@@ -496,6 +489,9 @@ class DiffContactInput:
# Local positions of the 1 vertex from the two geometries that define the support point for the face above
w_local_pos1: qd.Tensor
w_local_pos2: qd.Tensor
+ # Plane-convex contacts only: the convex support "core" (box vertex / sphere
+ # center / capsule nearest endpoint) in the convex geom's local frame.
+ core_local: qd.Tensor
# Reference id of the contact point, which is needed for the backward pass
ref_id: qd.Tensor
# Flag whether the contact data can be computed in numerically stable way in both the forward and backward passes
@@ -518,6 +514,7 @@ def get_diff_contact_input(_B, max_contacts_per_pair, is_active, requires_grad=F
local_pos2_c=V_VEC(3, dtype=gs.qd_float, shape=shape),
w_local_pos1=V_VEC(3, dtype=gs.qd_float, shape=shape),
w_local_pos2=V_VEC(3, dtype=gs.qd_float, shape=shape),
+ core_local=V_VEC(3, dtype=gs.qd_float, shape=shape),
ref_id=V(dtype=gs.qd_int, shape=shape),
valid=V(dtype=gs.qd_int, shape=shape),
ref_penetration=V(dtype=gs.qd_float, shape=shape, needs_grad=True),
diff --git a/genesis/utils/misc.py b/genesis/utils/misc.py
index c4cf160bda..034da71ac1 100644
--- a/genesis/utils/misc.py
+++ b/genesis/utils/misc.py
@@ -561,7 +561,7 @@ def qd_to_torch(
except AttributeError:
try:
tc = value.to_torch(copy=False)
- except (ValueError, RuntimeError):
+ except (ValueError, RuntimeError, TypeError):
if copy is False:
raise
tensor = _maybe_transpose(value.to_torch(), value, transpose)
diff --git a/tests/test_grad.py b/tests/test_grad.py
deleted file mode 100644
index d6276fd6e8..0000000000
--- a/tests/test_grad.py
+++ /dev/null
@@ -1,546 +0,0 @@
-import numpy as np
-import pytest
-import torch
-
-import genesis as gs
-from genesis.utils.geom import R_to_quat
-from genesis.utils.misc import qd_to_torch, qd_to_numpy, tensor_to_array
-from genesis.utils import set_random_seed
-
-from .utils import assert_allclose
-
-
-@pytest.mark.required
-@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
-def test_differentiable_push(show_viewer):
- HORIZON = 10
-
- scene = gs.Scene(
- sim_options=gs.options.SimOptions(
- dt=2e-3,
- substeps=10,
- requires_grad=True,
- ),
- mpm_options=gs.options.MPMOptions(
- lower_bound=(0.0, -1.0, 0.0),
- upper_bound=(1.0, 1.0, 0.55),
- ),
- viewer_options=gs.options.ViewerOptions(
- camera_pos=(2.5, -0.15, 2.42),
- camera_lookat=(0.5, 0.5, 0.1),
- ),
- show_viewer=show_viewer,
- )
-
- plane = scene.add_entity(
- gs.morphs.URDF(
- file="urdf/plane/plane.urdf",
- fixed=True,
- )
- )
- stick = scene.add_entity(
- morph=gs.morphs.Mesh(
- file="meshes/stirrer.obj",
- scale=0.6,
- pos=(0.5, 0.5, 0.05),
- euler=(90.0, 0.0, 0.0),
- ),
- material=gs.materials.Tool(
- friction=8.0,
- ),
- )
- obj = scene.add_entity(
- morph=gs.morphs.Box(
- lower=(0.2, 0.1, 0.05),
- upper=(0.4, 0.3, 0.15),
- ),
- material=gs.materials.MPM.Elastic(
- rho=500,
- ),
- )
- scene.build(n_envs=2)
-
- init_pos = gs.tensor([[0.3, 0.1, 0.28], [0.3, 0.1, 0.5]], requires_grad=True)
- stick.set_position(init_pos)
- pos_obj_init = gs.tensor([0.3, 0.3, 0.1], requires_grad=True)
- obj.set_position(pos_obj_init)
- v_obj_init = gs.tensor([0.0, -1.0, 0.0], requires_grad=True)
- obj.set_velocity(v_obj_init)
- goal = gs.tensor([0.5, 0.8, 0.05])
-
- loss = 0.0
- v_list = []
- for i in range(HORIZON):
- v_i = gs.tensor([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]], requires_grad=True)
- stick.set_velocity(vel=v_i)
- v_list.append(v_i)
-
- scene.step()
-
- if i == HORIZON // 2:
- mpm_particles = scene.get_state().solvers_state[scene.solvers.index(scene.mpm_solver)]
- loss += torch.pow(mpm_particles.pos[mpm_particles.active == 1] - goal, 2).sum()
-
- if i == HORIZON - 2:
- state = obj.get_state()
- loss += torch.pow(state.pos - goal, 2).sum()
- loss.backward()
-
- # TODO: It would be great to compare the gradient to its analytical or numerical value.
- for v_i in v_list[:-1]:
- assert (v_i.grad.abs() > gs.EPS).any()
- assert (v_list[-1].grad.abs() < gs.EPS).all()
-
-
-@pytest.mark.required
-@pytest.mark.precision("64")
-@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
-def test_diff_contact():
- RTOL = 1e-4
-
- scene = gs.Scene(
- sim_options=gs.options.SimOptions(
- dt=0.01,
- # Turn on differentiable mode
- requires_grad=True,
- ),
- show_viewer=False,
- )
-
- box_size = 0.25
- box_spacing = box_size
- vec_one = np.array([1.0, 1.0, 1.0])
- box_pos_offset = (0.0, 0.0, 0.0) + 0.5 * box_size * vec_one
-
- box0 = scene.add_entity(
- gs.morphs.Box(size=box_size * vec_one, pos=box_pos_offset),
- )
- box1 = scene.add_entity(
- gs.morphs.Box(size=box_size * vec_one, pos=box_pos_offset + 0.8 * box_spacing * np.array([0, 0, 1])),
- )
- scene.build()
- solver = scene.sim.rigid_solver
- collider = solver.collider
-
- # Set up initial configuration
- x_ang, y_ang, z_ang = 3.0, 3.0, 3.0
- box1.set_quat(R_to_quat(gs.euler_to_R([np.deg2rad(x_ang), np.deg2rad(y_ang), np.deg2rad(z_ang)])))
-
- box0_init_pos = box0.get_pos().clone()
- box1_init_pos = box1.get_pos().clone()
- box0_init_quat = box0.get_quat().clone()
- box1_init_quat = box1.get_quat().clone()
-
- ### Compute the initial loss and compute gradients using differentiable contact detection
- # Detect contact
- collider.detection()
-
- # Get contact outputs and their grads
- contacts = collider.get_contacts(as_tensor=True, to_torch=True, keep_batch_dim=True)
- normal = contacts["normal"].requires_grad_()
- position = contacts["position"].requires_grad_()
- penetration = contacts["penetration"].requires_grad_()
-
- loss = ((normal * position).sum(dim=-1) * penetration).sum()
- dL_dnormal = torch.autograd.grad(loss, normal, retain_graph=True)[0]
- dL_dposition = torch.autograd.grad(loss, position, retain_graph=True)[0]
- dL_dpenetration = torch.autograd.grad(loss, penetration)[0]
-
- # Compute analytical gradients of the geoms position and quaternion
- collider.backward(dL_dposition, dL_dnormal, dL_dpenetration)
- dL_dpos = qd_to_torch(solver.geoms_state.pos.grad)
- dL_dquat = qd_to_torch(solver.geoms_state.quat.grad)
-
- ### Compute directional derivatives along random directions
- FD_EPS = 1e-5
- TRIALS = 100
-
- def compute_dL_error(dL_dx, x_type):
- dL_error_rel = 0.0
-
- box0_input_pos = box0_init_pos
- box1_input_pos = box1_init_pos
- box0_input_quat = box0_init_quat
- box1_input_quat = box1_init_quat
-
- for _ in range(TRIALS):
- rand_dx = torch.randn_like(dL_dx)
- rand_dx = torch.nn.functional.normalize(rand_dx, dim=-1)
-
- dL = (rand_dx * dL_dx).sum()
-
- lossPs = []
- for sign in (1, -1):
- # Compute query point
- if x_type == "pos":
- box0_input_pos = box0_init_pos + sign * rand_dx[0, 0] * FD_EPS
- box1_input_pos = box1_init_pos + sign * rand_dx[1, 0] * FD_EPS
- else:
- # FIXME: The quaternion should be normalized
- box0_input_quat = box0_init_quat + sign * rand_dx[0, 0] * FD_EPS
- box1_input_quat = box1_init_quat + sign * rand_dx[1, 0] * FD_EPS
-
- # Update box positions
- box0.set_pos(box0_input_pos)
- box1.set_pos(box1_input_pos)
- box0.set_quat(box0_input_quat)
- box1.set_quat(box1_input_quat)
-
- # Re-detect contact.
- # We need to manually reset the contact counter as we are not running the whole sim step.
- collider._collider_state.n_contacts.fill(0)
- collider.detection()
- contacts = collider.get_contacts(as_tensor=True, to_torch=True, keep_batch_dim=True)
- normal, position, penetration = contacts["normal"], contacts["position"], contacts["penetration"]
-
- # Compute loss
- loss = ((normal * position).sum(dim=-1) * penetration).sum()
- lossPs.append(loss)
-
- dL_fd = (lossPs[0] - lossPs[1]) / (2 * FD_EPS)
- dL_error_rel += (dL - dL_fd).abs() / max(dL.abs(), dL_fd.abs(), gs.EPS)
-
- dL_error_rel /= TRIALS
- return dL_error_rel
-
- dL_dpos_error_rel = compute_dL_error(dL_dpos, "pos")
- assert_allclose(dL_dpos_error_rel, 0.0, atol=RTOL)
- dL_dquat_error_rel = compute_dL_error(dL_dquat, "quat")
- assert_allclose(dL_dquat_error_rel, 0.0, atol=RTOL)
-
-
-# We need to use 64-bit precision for this test because we need to use sufficiently small perturbation to get reliable
-# gradient estimates through finite difference method. This small perturbation is not supported by 32-bit precision in
-# stable way.
-@pytest.mark.required
-@pytest.mark.precision("64")
-@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
-def test_diff_solver(monkeypatch):
- from genesis.engine.solvers.rigid.constraint.solver import func_solve_init, func_solve_body
- from genesis.engine.solvers.rigid.rigid_solver import kernel_step_1
-
- RTOL = 1e-4
-
- scene = gs.Scene(
- sim_options=gs.options.SimOptions(
- dt=0.01,
- requires_grad=True,
- ),
- rigid_options=gs.options.RigidOptions(
- # We use Newton's method because it converges faster than CG, and therefore gives better gradient estimation
- # when using finite difference method
- constraint_solver=gs.constraint_solver.Newton,
- ),
- show_viewer=False,
- )
-
- scene.add_entity(gs.morphs.Plane(pos=(0, 0, 0)))
- scene.add_entity(gs.morphs.Box(size=(1, 1, 1), pos=(10, 10, 0.49)))
- franka = scene.add_entity(
- gs.morphs.MJCF(file="xml/franka_emika_panda/panda.xml"),
- )
-
- scene.build()
- rigid_solver = scene._sim.rigid_solver
- constraint_solver = rigid_solver.constraint_solver
-
- franka.set_qpos([-1.0124, 1.5559, 1.3662, -1.6878, -1.5799, 1.7757, 1.4602, 0.04, 0.04])
-
- # Monkeypatch the constraint resolve function to avoid overwriting the necessary information for computing gradients.
- def constraint_solver_resolve():
- func_solve_init(
- dofs_info=rigid_solver.dofs_info,
- dofs_state=rigid_solver.dofs_state,
- entities_info=rigid_solver.entities_info,
- constraint_state=constraint_solver.constraint_state,
- rigid_global_info=rigid_solver._rigid_global_info,
- static_rigid_sim_config=rigid_solver._static_rigid_sim_config,
- )
- func_solve_body(
- entities_info=rigid_solver.entities_info,
- dofs_info=rigid_solver.dofs_info,
- dofs_state=rigid_solver.dofs_state,
- constraint_state=constraint_solver.constraint_state,
- rigid_global_info=rigid_solver._rigid_global_info,
- static_rigid_sim_config=rigid_solver._static_rigid_sim_config,
- _n_iterations=constraint_solver._n_iterations,
- )
-
- monkeypatch.setattr(constraint_solver, "resolve", constraint_solver_resolve)
-
- # Step once to compute constraint solver's inputs: [mass], [jac], [aref], [efc_D], [force]. We do not call the
- # entire scene.step() because it will overwrite the necessary information that we need to compute the gradients.
- kernel_step_1(
- links_state=rigid_solver.links_state,
- links_info=rigid_solver.links_info,
- joints_state=rigid_solver.joints_state,
- joints_info=rigid_solver.joints_info,
- dofs_state=rigid_solver.dofs_state,
- dofs_info=rigid_solver.dofs_info,
- geoms_state=rigid_solver.geoms_state,
- geoms_info=rigid_solver.geoms_info,
- entities_state=rigid_solver.entities_state,
- entities_info=rigid_solver.entities_info,
- rigid_global_info=rigid_solver._rigid_global_info,
- static_rigid_sim_config=rigid_solver._static_rigid_sim_config,
- contact_island_state=constraint_solver.contact_island.contact_island_state,
- is_forward_pos_updated=True,
- is_forward_vel_updated=True,
- is_backward=False,
- )
- constraint_solver.add_equality_constraints()
- rigid_solver.collider.detection()
- constraint_solver.add_inequality_constraints()
- constraint_solver.resolve()
-
- # Loss function to compute gradients using finite difference method
- def compute_loss(input_mass, input_jac, input_aref, input_efc_D, input_force):
- rigid_solver._rigid_global_info.mass_mat.from_numpy(input_mass)
- constraint_solver.constraint_state.jac.from_numpy(input_jac)
- constraint_solver.constraint_state.aref.from_numpy(input_aref)
- constraint_solver.constraint_state.efc_D.from_numpy(input_efc_D)
- rigid_solver.dofs_state.force.from_numpy(input_force)
-
- # Recompute acc_smooth from the updated input variables
- updated_acc_smooth = np.linalg.solve(input_mass[..., 0], input_force[..., 0])
- rigid_solver.dofs_state.acc_smooth.from_numpy(updated_acc_smooth[..., None])
- constraint_solver.resolve()
-
- output_qacc = qd_to_torch(constraint_solver.qacc)
- return ((output_qacc - target_qacc) ** 2).mean()
-
- init_input_mass = qd_to_numpy(rigid_solver._rigid_global_info.mass_mat, copy=True)
- init_input_jac = qd_to_numpy(constraint_solver.constraint_state.jac, copy=True)
- init_input_aref = qd_to_numpy(constraint_solver.constraint_state.aref, copy=True)
- init_input_efc_D = qd_to_numpy(constraint_solver.constraint_state.efc_D, copy=True)
- init_input_force = qd_to_numpy(rigid_solver.dofs_state.force, copy=True)
-
- # Initial output of the constraint solver
- set_random_seed(0)
- init_output_qacc = qd_to_torch(constraint_solver.qacc)
- target_qacc = torch.from_numpy(np.random.randn(*init_output_qacc.shape)).to(device=gs.device)
- target_qacc = target_qacc * init_output_qacc.abs().mean()
-
- # Solve the constraint solver and get the output
- output_qacc = qd_to_torch(constraint_solver.qacc, copy=True).requires_grad_(True)
-
- # Compute loss and gradient of the output
- loss = ((output_qacc - target_qacc) ** 2).mean()
- dL_dqacc = tensor_to_array(torch.autograd.grad(loss, output_qacc)[0])
-
- # Compute gradients of the input variables: [mass], [jac], [aref], [efc_D], [force]
- constraint_solver.backward(dL_dqacc)
-
- # Fetch gradients of the input variables
- dL_dM = qd_to_numpy(constraint_solver.constraint_state.dL_dM)
- dL_djac = qd_to_numpy(constraint_solver.constraint_state.dL_djac)
- dL_daref = qd_to_numpy(constraint_solver.constraint_state.dL_daref)
- dL_defc_D = qd_to_numpy(constraint_solver.constraint_state.dL_defc_D)
- dL_dforce = qd_to_numpy(constraint_solver.constraint_state.dL_dforce)
-
- ### Compute directional derivatives along random directions
- FD_EPS = 1e-3
- TRIALS = 200
-
- for dL_dx, x_type in (
- (dL_dforce, "force"),
- (dL_daref, "aref"),
- (dL_defc_D, "efc_D"),
- (dL_djac, "jac"),
- (dL_dM, "mass"),
- ):
- dL_error = 0.0
- for _ in range(TRIALS):
- rand_dx = np.random.randn(*dL_dx.shape)
- rand_dx = rand_dx / max(
- np.linalg.norm(rand_dx, axis=0 if x_type in ("force", "aref", "efc_D") else (0, 1)), gs.EPS
- )
- if x_type == "mass":
- # Make rand_dx symmetric
- rand_dx = (rand_dx + np.moveaxis(rand_dx, 0, 1)) * 0.5
-
- dL = (rand_dx * dL_dx).sum()
-
- input_force = init_input_force
- input_aref = init_input_aref
- input_efc_D = init_input_efc_D
- input_jac = init_input_jac
- input_mass = init_input_mass
-
- # 1 * eps
- if x_type == "force":
- input_force = init_input_force + rand_dx * FD_EPS
- elif x_type == "aref":
- input_aref = init_input_aref + rand_dx * FD_EPS
- elif x_type == "efc_D":
- input_efc_D = init_input_efc_D + rand_dx * FD_EPS
- elif x_type == "jac":
- input_jac = init_input_jac + rand_dx * FD_EPS
- elif x_type == "mass":
- input_mass = init_input_mass + rand_dx * FD_EPS
- lossP1 = compute_loss(input_mass, input_jac, input_aref, input_efc_D, input_force)
-
- # -1 * eps
- if x_type == "force":
- input_force = init_input_force - rand_dx * FD_EPS
- elif x_type == "aref":
- input_aref = init_input_aref - rand_dx * FD_EPS
- elif x_type == "efc_D":
- input_efc_D = init_input_efc_D - rand_dx * FD_EPS
- elif x_type == "jac":
- input_jac = init_input_jac - rand_dx * FD_EPS
- elif x_type == "mass":
- input_mass = init_input_mass - rand_dx * FD_EPS
-
- lossP2 = compute_loss(input_mass, input_jac, input_aref, input_efc_D, input_force)
- dL_fd = (lossP1 - lossP2) / (2 * FD_EPS)
-
- dL_error += (dL - dL_fd).abs() / max(abs(dL), abs(dL_fd), gs.EPS)
-
- dL_error /= TRIALS
- assert_allclose(dL_error, 0.0, atol=RTOL)
-
-
-@pytest.mark.slow # ~250s
-@pytest.mark.required
-@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
-def test_differentiable_rigid(show_viewer):
- dt = 1e-2
- horizon = 100
- substeps = 1
- goal_pos = gs.tensor([0.7, 1.0, 0.05])
- goal_quat = gs.tensor([0.3, 0.2, 0.1, 0.9])
- goal_quat = goal_quat / torch.norm(goal_quat, dim=-1, keepdim=True)
-
- scene = gs.Scene(
- sim_options=gs.options.SimOptions(
- dt=dt,
- substeps=substeps,
- requires_grad=True,
- gravity=(0, 0, -1),
- ),
- rigid_options=gs.options.RigidOptions(
- enable_collision=False,
- enable_self_collision=False,
- enable_joint_limit=False,
- disable_constraint=True,
- use_contact_island=False,
- use_hibernation=False,
- ),
- viewer_options=gs.options.ViewerOptions(
- camera_pos=(2.5, -0.15, 2.42),
- camera_lookat=(0.5, 0.5, 0.1),
- ),
- show_viewer=show_viewer,
- )
-
- box = scene.add_entity(
- gs.morphs.Box(
- pos=(0, 0, 0),
- size=(0.1, 0.1, 0.2),
- ),
- surface=gs.surfaces.Default(
- color=(0.9, 0.0, 0.0, 1.0),
- ),
- )
- if show_viewer:
- target = scene.add_entity(
- gs.morphs.Box(
- pos=goal_pos,
- quat=goal_quat,
- size=(0.1, 0.1, 0.2),
- ),
- surface=gs.surfaces.Default(
- color=(0.0, 0.9, 0.0, 0.5),
- ),
- )
-
- scene.build()
-
- num_iter = 200
- lr = 1e-2
-
- init_pos = gs.tensor([0.3, 0.1, 0.28], requires_grad=True)
- init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True)
- optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr)
-
- scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3)
-
- for _ in range(num_iter):
- scene.reset()
-
- box.set_pos(init_pos)
- box.set_quat(init_quat)
-
- loss = 0
- for _ in range(horizon):
- scene.step()
- if show_viewer:
- target.set_pos(goal_pos)
- target.set_quat(goal_quat)
-
- box_state = box.get_state()
- box_pos = box_state.pos
- box_quat = box_state.quat
- loss = torch.abs(box_pos - goal_pos).sum() + torch.abs(box_quat - goal_quat).sum()
-
- optimizer.zero_grad()
- loss.backward() # this lets gradient flow all the way back to tensor input
- optimizer.step()
- scheduler.step()
-
- with torch.no_grad():
- init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True)
-
- assert_allclose(loss, 0.0, atol=1e-2)
-
-
-@pytest.mark.required
-@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
-def test_diff_sim_vs_solver_state_grad_parity(show_viewer):
- scene = gs.Scene(
- sim_options=gs.options.SimOptions(
- dt=0.01,
- gravity=(0.0, 0.0, 0.0),
- requires_grad=True,
- ),
- rigid_options=gs.options.RigidOptions(
- enable_collision=False,
- ),
- show_viewer=show_viewer,
- )
- robot = scene.add_entity(
- gs.morphs.Box(
- size=(0.1, 0.1, 0.1),
- pos=(0, 0, 0),
- )
- )
- scene.build()
-
- ctrl = gs.tensor(np.random.randn(robot.n_dofs), dtype=gs.tc_float, requires_grad=True)
-
- grads = []
- for use_sim_state in (False, True):
- scene.reset()
-
- robot.set_dofs_velocity(ctrl)
- scene.step()
-
- if use_sim_state:
- solver_state = scene.get_state().solvers_state[scene.solvers.index(scene.rigid_solver)]
- chassis_pos = solver_state.links_pos[:, 0].squeeze()
- else:
- chassis_pos = robot.get_state().pos.squeeze()
-
- loss = torch.linalg.norm(chassis_pos)
- loss.backward()
- grad = ctrl.grad.detach().clone()
- ctrl.grad.zero_()
-
- # Basic sanity check
- assert (grad[..., :3].abs() > gs.EPS).all()
- assert (grad[..., 3:].abs() < gs.EPS).all()
-
- grads.append(grad)
-
- assert_allclose(*grads, atol=gs.EPS)
diff --git a/tests/test_grad_fd.py b/tests/test_grad_fd.py
new file mode 100644
index 0000000000..536ea9cac2
--- /dev/null
+++ b/tests/test_grad_fd.py
@@ -0,0 +1,1932 @@
+"""FD-vs-analytical gradient checks for the differentiable rigid solver.
+
+This is the *correctness* layer of the grad test suite: every test compares the
+diff-mode analytical gradient against central finite differences. It is split
+into three sections, innermost backward layer first:
+
+ 1. Forward-kinematics (constraints OFF): the unconstrained FK + velocity
+ gradient, per joint topology — the base local-gradient bar.
+ 2. Joint-limit (enable_joint_limit=True): the joint-limit inequality
+ constraint reverse, plus forward enforcement.
+ 3. Contact (enable_collision=True): the collision constraint +
+ diff-GJK narrow-phase reverse (box-box and plane-convex).
+
+Each section keeps its OWN scene builder — the FK builder disables all
+constraints, while the joint-limit / contact builders turn the relevant
+constraint on — so the configs never bleed across sections.
+"""
+
+import math
+
+import numpy as np
+import pytest
+import torch
+
+import genesis as gs
+from genesis.utils import set_random_seed
+from genesis.utils.geom import R_to_quat
+from genesis.utils.misc import qd_to_numpy, qd_to_torch, tensor_to_array
+
+from .utils import assert_allclose
+
+
+# Module-wide marker applied to every test collected from this file.
+# NOTE(review): `debug` is a project-specific marker; presumably it switches debug mode off — confirm in conftest.
+pytestmark = [
+    pytest.mark.debug(False),
+]
+
+
+# Per-precision FD tolerance, keyed by (precision, kind). fp32 is intentionally looser.
+# Each value holds the `rtol` / `atol` of the comparison plus the `eps` step of the central
+# finite difference; tests splat the dict into `_grad_matches_fd(...)`.
+# The "quat" kind covers outputs that go through a non-linear pose composition
+# (set_dofs_velocity → state.quat) where Genesis autograd is currently ~1%
+# noisier than FD.
+_TOL = {
+    ("64", "default"): dict(rtol=1e-4, atol=1e-6, eps=1e-5),
+    ("64", "quat"): dict(rtol=2e-2, atol=1e-3, eps=1e-5),
+    ("32", "default"): dict(rtol=2e-2, atol=2e-3, eps=1e-3),
+    ("32", "quat"): dict(rtol=5e-2, atol=5e-3, eps=1e-3),
+}
+
+
+# Precision axis. The parametrized value is the string used as the first component of a `_TOL`
+# key; each param additionally carries a `precision` mark.
+# NOTE(review): the mark is presumably consumed by the test harness to pick the float width — confirm.
+_PRECISION_PARAMS = [
+    pytest.param("64", marks=pytest.mark.precision("64"), id="fp64"),
+    pytest.param("32", marks=pytest.mark.precision("32"), id="fp32"),
+]
+
+# Batch axis: the value is forwarded to `scene.build(n_envs=...)` by the scene builder
+# (0 → "single", 4 → "batched").
+_N_ENVS_PARAMS = [
+    pytest.param(0, id="single"),
+    pytest.param(4, id="batched"),
+]
+
+# Substep axis: the value is forwarded to `SimOptions(substeps=...)` by the scene builder.
+_SUBSTEPS_PARAMS = [
+    pytest.param(1, id="ss1"),
+    pytest.param(4, id="ss4"),
+]
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _build_scene(mjcf_path: str, *, requires_grad: bool, n_envs: int = 0, substeps: int = 1):
+    """Build a constraint-free rigid scene that contains a single MJCF entity.
+
+    Collision, self-collision, joint limits, constraints, hibernation and contact islands are
+    all switched off, and the viewer is never shown.
+
+    Args:
+        mjcf_path: Asset path handed to `gs.morphs.MJCF(file=...)`.
+        requires_grad: Forwarded to `SimOptions(requires_grad=...)`.
+        n_envs: Forwarded to `scene.build(n_envs=...)`.
+        substeps: Forwarded to `SimOptions(substeps=...)`.
+
+    Returns:
+        The built `(scene, robot)` pair.
+    """
+    sim_options = gs.options.SimOptions(
+        dt=0.01,
+        substeps=substeps,
+        gravity=(0.0, 0.0, -9.81),
+        requires_grad=requires_grad,
+    )
+    # Every constraint-related feature is disabled: this builder serves the FK-only checks.
+    rigid_options = gs.options.RigidOptions(
+        enable_collision=False,
+        enable_self_collision=False,
+        enable_joint_limit=False,
+        disable_constraint=True,
+        use_hibernation=False,
+        use_contact_island=False,
+    )
+    built_scene = gs.Scene(sim_options=sim_options, rigid_options=rigid_options, show_viewer=False)
+    entity = built_scene.add_entity(gs.morphs.MJCF(file=mjcf_path))
+    built_scene.build(n_envs=n_envs)
+    return built_scene, entity
+
+
+def _make_scene_pair(mjcf_file: str, n_envs: int = 0, substeps: int = 1):
+    """Build the analytical / finite-difference scene pair for one MJCF file.
+
+    Two independent scenes are created from the same asset:
+
+    * The *analytical* scene (`requires_grad=True`) runs the differentiable-mode forward and is
+      the only one `loss.backward()` is ever called on. Once a backward has run, that scene's
+      internal target-replay state silently ignores subsequent setters, so probing it with FD
+      would give loss_p == loss_m and a fake zero gradient.
+    * The *FD* scene (`requires_grad=False`) runs the production forward and is the one FD
+      perturbs. It never sees a backward, so each reset → set → step cycle is clean.
+
+    The FD comparison therefore checks "does the diff-mode analytical gradient match the
+    production forward's local sensitivity". With `n_envs > 0` both scenes run in batched mode,
+    which lets the tests verify that per-env adjoints are independently correct.
+
+    Returns:
+        `(scene_ana, robot_ana, scene_fd, robot_fd, mjcf_file)`.
+    """
+    shared = dict(n_envs=n_envs, substeps=substeps)
+    # The differentiable scene is built first, then its production-mode twin.
+    diff_scene, diff_robot = _build_scene(mjcf_file, requires_grad=True, **shared)
+    prod_scene, prod_robot = _build_scene(mjcf_file, requires_grad=False, **shared)
+    return diff_scene, diff_robot, prod_scene, prod_robot, mjcf_file
+
+
+def _batch_size(scene) -> int:
+    """Return the effective batch dimension of `scene`.
+
+    A scene with `n_envs == 0` is reported as batch size 1 (per the original note, such a scene
+    still allocates B=1 internally); any positive `n_envs` is returned unchanged.
+    """
+    n_envs = scene.n_envs
+    if n_envs > 0:
+        return n_envs
+    return 1
+
+
+def _input_shape(base_shape, n_envs):
+    """Return the shape a setter input must have for the given batching mode.
+
+    The input is unbatched (`base_shape`) when `n_envs == 0` and gains a leading batch axis,
+    `(n_envs, *base_shape)`, otherwise. The result is always a tuple.
+    """
+    base = tuple(base_shape)
+    if n_envs > 0:
+        return (n_envs, *base)
+    return base
+
+
+def _solver_state(scene):
+    """Return the rigid solver's entry of `scene.get_state().solvers_state`.
+
+    Per the original note this is a grad-aware RigidSolverState that exposes `links_pos` /
+    `links_quat` for every link in the entity.
+    """
+    snapshot = scene.get_state()
+    # The per-solver states are looked up at the rigid solver's position in `scene.solvers`.
+    rigid_index = scene.solvers.index(scene.rigid_solver)
+    return snapshot.solvers_state[rigid_index]
+
+
+def _grad_matches_fd(
+    scene_ana,
+    robot_ana,
+    scene_fd,
+    robot_fd,
+    init_input,  # array-like (converted to fp64); any shape — perturbed entry by entry in flattened order
+    apply_fn,  # callable(robot, x): apply x via a @tracked setter
+    loss_fn,  # callable(scene, robot) -> scalar tensor
+    *,
+    label: str,
+    rtol: float = 1e-4,
+    atol: float = 1e-6,
+    eps: float = 1e-5,
+):
+    """Assert that the analytical gradient of a one-step loss matches central finite differences.
+
+    The analytical gradient comes from `scene_ana`: reset, apply the tracked input, step once,
+    evaluate `loss_fn` and call `backward()`. The reference gradient comes from `scene_fd`: for
+    every entry of the input, the same reset → apply → step → loss cycle is run at `+eps` and
+    `-eps` and the two losses are differenced.
+
+    Args:
+        scene_ana, robot_ana: Scene/entity the analytical gradient is taken on.
+        scene_fd, robot_fd: Scene/entity the finite-difference probes run on.
+        init_input: Nominal input value; copied, never modified.
+        apply_fn: Applies an input tensor to a robot.
+        loss_fn: Maps `(scene, robot)` to a scalar loss tensor.
+        label: Prefix used in every failure message.
+        rtol, atol: Tolerances forwarded to `assert_allclose`.
+        eps: Half-width of the central difference.
+
+    Raises:
+        AssertionError: If the loss is not grad-aware, if the input receives no gradient, or if
+            the two gradients disagree beyond the tolerances.
+    """
+    # NOTE on tolerances (carried over from the original author): the production-mode and
+    # diff-mode forward kernels were verified to produce bit-identical state.pos/state.quat for
+    # the same input (probe_optionB.py, 2026-05-03), so an FD probed on the no-grad scene is a
+    # valid reference for the diff scene's analytical gradient.
+    #
+    # Most input/output pairs hit rtol=1e-4 trivially. The set_dofs_velocity → state.quat path is
+    # the outlier: it carries a known ~1% systematic drift between Genesis autograd and central
+    # FD (output magnitude is ~1e-2, so the absolute mismatch sits at ~1e-4 — well above
+    # truncation/roundoff at fp64). Tracked as a separate followup; for those cases callers
+    # should pass a looser `rtol` (e.g. 2e-2) rather than tightening this default.
+    nominal = np.asarray(init_input, dtype=np.float64).copy()
+
+    # --- analytical (diff-mode scene) ---
+    tracked = gs.tensor(nominal, dtype=gs.tc_float, requires_grad=True)
+    scene_ana.reset()
+    apply_fn(robot_ana, tracked)
+    scene_ana.step()
+    loss = loss_fn(scene_ana, robot_ana)
+    assert loss.requires_grad, f"[{label}] loss does not require grad — output is not grad-aware"
+    loss.backward()
+    assert tracked.grad is not None, f"[{label}] x.grad is None after backward"
+    grad_analytical = tracked.grad.detach().cpu().numpy().copy()
+
+    # --- central FD (production-mode scene) ---
+    def _probe_loss(entry, sign):
+        """Loss after one step with flattened entry `entry` shifted by `sign * eps`."""
+        shifted = nominal.copy()
+        shifted.flat[entry] += sign * eps
+        scene_fd.reset()
+        apply_fn(robot_fd, gs.tensor(shifted, dtype=gs.tc_float))
+        scene_fd.step()
+        return float(loss_fn(scene_fd, robot_fd).detach().cpu())
+
+    grad_fd = np.zeros_like(nominal)
+    for entry in range(nominal.size):
+        loss_plus = _probe_loss(entry, +1)
+        loss_minus = _probe_loss(entry, -1)
+        grad_fd.flat[entry] = (loss_plus - loss_minus) / (2.0 * eps)
+
+    assert_allclose(
+        torch.from_numpy(grad_analytical),
+        torch.from_numpy(grad_fd),
+        rtol=rtol,
+        atol=atol,
+        err_msg=f"[{label}] FD vs analytical mismatch",
+    )
+
+
+def _grad_matches_fd_multistep(
+    scene_ana,
+    robot_ana,
+    scene_fd,
+    robot_fd,
+    init_inputs,  # list[np.ndarray] — one input per timestep, each shape matches the setter's expectation
+    apply_fn,  # callable(robot, x): apply x via a @tracked setter
+    loss_fn,  # callable(scene, robot) -> scalar tensor
+    *,
+    label: str,
+    rtol: float = 1e-4,
+    atol: float = 1e-6,
+    eps: float = 1e-5,
+):
+    """Multi-step counterpart of `_grad_matches_fd`.
+
+    Runs `N = len(init_inputs)` simulator steps, applying a different `@tracked`-setter input
+    before each step, and evaluates the loss once at the end. After `loss.backward()` the
+    simulator must deliver a correct adjoint for each step's input independently (i.e.
+    `scene._backward()` correctly walks the per-substep `process_input_grad` chain).
+
+    The FD reference perturbs each entry of each step's input separately and re-runs the full
+    N-step trajectory on `scene_fd`. Cost is O(N · sum_inputs_size) forward runs of N steps each;
+    with N=10 and n_dofs ~ 3-7 this is ~600-1400 step calls per topology.
+
+    Raises:
+        AssertionError: If the loss is not grad-aware, if any step input receives no gradient,
+            or if any per-step gradient disagrees with FD beyond `rtol` / `atol`.
+    """
+    n_steps = len(init_inputs)
+    nominal = [np.asarray(step_input, dtype=np.float64).copy() for step_input in init_inputs]
+
+    # --- analytical (diff-mode scene) ---
+    scene_ana.reset()
+    tracked = []
+    for step in range(n_steps):
+        x = gs.tensor(nominal[step], dtype=gs.tc_float, requires_grad=True)
+        tracked.append(x)
+        apply_fn(robot_ana, x)
+        scene_ana.step()
+    loss = loss_fn(scene_ana, robot_ana)
+    assert loss.requires_grad, f"[{label}] loss does not require grad — output is not grad-aware"
+    loss.backward()
+    ana_grads = []
+    for step, x in enumerate(tracked):
+        assert x.grad is not None, f"[{label}] step {step}: x.grad is None after backward"
+        ana_grads.append(x.grad.detach().cpu().numpy().copy())
+
+    # --- central FD (production-mode scene) ---
+    def _rollout_loss(step_hit, entry_hit, sign):
+        """Final loss of a full trajectory with one entry of one step's input shifted by `sign * eps`."""
+        scene_fd.reset()
+        for step in range(n_steps):
+            probe = nominal[step].copy()
+            # Only the selected step is perturbed; every other step replays its nominal input.
+            if step == step_hit:
+                probe.flat[entry_hit] += sign * eps
+            apply_fn(robot_fd, gs.tensor(probe, dtype=gs.tc_float))
+            scene_fd.step()
+        return float(loss_fn(scene_fd, robot_fd).detach().cpu())
+
+    fd_grads = [np.zeros_like(step_input) for step_input in nominal]
+    for step in range(n_steps):
+        for entry in range(nominal[step].size):
+            loss_plus = _rollout_loss(step, entry, +1)
+            loss_minus = _rollout_loss(step, entry, -1)
+            fd_grads[step].flat[entry] = (loss_plus - loss_minus) / (2.0 * eps)
+
+    for step in range(n_steps):
+        assert_allclose(
+            torch.from_numpy(ana_grads[step]),
+            torch.from_numpy(fd_grads[step]),
+            rtol=rtol,
+            atol=atol,
+            err_msg=f"[{label}] step {step}: FD vs analytical mismatch",
+        )
+
+
+# loss factories — all use sum-of-squared-deviation to a fixed random target so
+# every entry of the input has a nontrivial sensitivity. Targets and outputs are
+# both flattened before the subtraction so multi-link shapes (B, n_links, 3|4)
+# don't trip torch broadcasting.
+def _loss_state_pos(target):
+    """Return a loss closure: sum of squared deviations of `robot.get_state().pos` from `target`.
+
+    `target` is flattened once, up front; the closure flattens the position the same way so the
+    subtraction is element-wise with no broadcasting.
+    """
+    target_flat = target.reshape(-1)
+
+    def _squared_error(scene, robot):
+        residual = robot.get_state().pos.reshape(-1) - target_flat
+        return (residual**2).sum()
+
+    return _squared_error
+
+
+def _loss_state_quat(target):
+    """Return a loss closure: sum of squared deviations of `robot.get_state().quat` from `target`.
+
+    `target` is flattened once, up front; the closure flattens the quaternion the same way so
+    the subtraction is element-wise with no broadcasting.
+    """
+    target_flat = target.reshape(-1)
+
+    def _squared_error(scene, robot):
+        residual = robot.get_state().quat.reshape(-1) - target_flat
+        return (residual**2).sum()
+
+    return _squared_error
+
+
+def _loss_links_pos(target):
+    """Return a loss closure: sum of squared deviations of the solver state's `links_pos` from `target`.
+
+    Both operands are flattened before the subtraction so multi-link shapes do not trigger
+    torch broadcasting.
+    """
+    target_flat = target.reshape(-1)
+
+    def _squared_error(scene, robot):
+        residual = _solver_state(scene).links_pos.reshape(-1) - target_flat
+        return (residual**2).sum()
+
+    return _squared_error
+
+
+def _loss_links_quat(target):
+    """Return a loss closure: sum of squared deviations of the solver state's `links_quat` from `target`.
+
+    Both operands are flattened before the subtraction so multi-link shapes do not trigger
+    torch broadcasting.
+    """
+    target_flat = target.reshape(-1)
+
+    def _squared_error(scene, robot):
+        residual = _solver_state(scene).links_quat.reshape(-1) - target_flat
+        return (residual**2).sum()
+
+    return _squared_error
+
+
+def _rand_np(shape, seed):
+    """Return a float64 array of standard-normal samples of the given shape.
+
+    A fresh `np.random.default_rng(seed)` is created on every call, so the result depends only
+    on `(shape, seed)` and never on global random state.
+    """
+    samples = np.random.default_rng(seed).standard_normal(shape)
+    return samples.astype(np.float64)
+
+
+def _target(shape, seed):
+    """Return a seeded standard-normal target tensor of `shape`, cast to `gs.tc_float` on `gs.device`."""
+    values = _rand_np(shape, seed)
+    return torch.from_numpy(values).to(dtype=gs.tc_float, device=gs.device)
+
+
+# ---------------------------------------------------------------------------
+# Tests — one per joint topology, several (input, output) checks inside.
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.required
+@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+@pytest.mark.parametrize("precision", _PRECISION_PARAMS)
+@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS)
+@pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS)
+def test_diff_fk_freejoint(show_viewer, n_envs, precision, substeps):
+    """J1: single free body.
+
+    Parametrized over n_envs ∈ {0, 4}, precision ∈ {fp64, fp32} and substeps ∈ {1, 4}. Checks
+    set_pos, set_quat, set_dofs_velocity and (fp64 only) control_dofs_force against FD.
+    """
+    scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair(
+        "xml/grad/free.xml", n_envs=n_envs, substeps=substeps
+    )
+    batch = _batch_size(scene_ana)
+    dof_shape = _input_shape((robot_ana.n_dofs,), n_envs)
+    tol_default = _TOL[(precision, "default")]
+    tol_quat = _TOL[(precision, "quat")]
+    pos_loss = _loss_state_pos(_target((batch, 3), seed=1))
+    quat_loss = _loss_state_quat(_target((batch, 4), seed=2))
+
+    def _check(label, init_input, apply_fn, loss_fn, tol):
+        _grad_matches_fd(
+            scene_ana,
+            robot_ana,
+            scene_fd,
+            robot_fd,
+            init_input=init_input,
+            apply_fn=apply_fn,
+            loss_fn=loss_fn,
+            label=label,
+            **tol,
+        )
+
+    def _set_velocity(robot, x):
+        return robot.set_dofs_velocity(x)
+
+    _check(
+        "J1 set_pos → state.pos",
+        _rand_np(_input_shape((3,), n_envs), seed=10),
+        lambda r, x: r.set_pos(x),
+        pos_loss,
+        tol_default,
+    )
+
+    # Nominal orientation: identity quaternion plus a small random offset, re-normalized.
+    quat_shape = _input_shape((4,), n_envs)
+    quat0 = np.broadcast_to(np.array([1.0, 0.0, 0.0, 0.0]), quat_shape).copy()
+    quat0 = quat0 + 0.05 * _rand_np(quat_shape, seed=11)
+    quat0 = quat0 / np.linalg.norm(quat0, axis=-1, keepdims=True)
+    _check("J1 set_quat → state.quat", quat0, lambda r, x: r.set_quat(x), quat_loss, tol_quat)
+
+    _check(
+        "J1 set_dofs_velocity → state.pos (after 1 step)",
+        _rand_np(dof_shape, seed=12),
+        _set_velocity,
+        pos_loss,
+        tol_default,
+    )
+    _check(
+        "J1 set_dofs_velocity → state.quat (after 1 step)",
+        _rand_np(dof_shape, seed=13),
+        _set_velocity,
+        quat_loss,
+        tol_quat,
+    )
+
+    # fp64 only: d(state.pos)/d(force) ≈ dt^2 / (2 * inertia) ≈ 1e-4 after 1 step. At fp32 with
+    # FD eps=1e-3 the loss difference is ~1e-7 — at fp32's precision floor — and the FD probe
+    # disagrees with analytical by ~1e-4 absolute, well above the fp32 default tol band.
+    # Several force → position checks further down (e.g. J3, J4, J8) are gated to fp64 for the
+    # same reason; J2's `control_dofs_force → state.quat` is left ungated because its check uses
+    # the wider quat tolerance.
+    if precision == "64":
+        _check(
+            "J1 control_dofs_force → state.pos (after 1 step)",
+            _rand_np(dof_shape, seed=14),
+            lambda r, x: r.control_dofs_force(x),
+            pos_loss,
+            tol_default,
+        )
+
+
+@pytest.mark.required
+@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+@pytest.mark.parametrize("precision", _PRECISION_PARAMS)
+@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS)
+@pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS)
+def test_diff_fk_revolute(show_viewer, n_envs, precision, substeps):
+    """J2: single revolute joint, fixed base.
+
+    Parametrized over n_envs ∈ {0, 4}, precision ∈ {fp64, fp32} and substeps ∈ {1, 4}.
+    """
+    scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair(
+        "xml/grad/revolute.xml", n_envs=n_envs, substeps=substeps
+    )
+    batch = _batch_size(scene_ana)
+    dof_shape = _input_shape((robot_ana.n_dofs,), n_envs)  # n_dofs = 1
+    tol_default = _TOL[(precision, "default")]
+    tol_quat = _TOL[(precision, "quat")]
+    pos_loss = _loss_state_pos(_target((batch, 3), seed=21))
+    quat_loss = _loss_state_quat(_target((batch, 4), seed=22))
+
+    def _check(label, init_input, apply_fn, loss_fn, tol):
+        _grad_matches_fd(
+            scene_ana,
+            robot_ana,
+            scene_fd,
+            robot_fd,
+            init_input=init_input,
+            apply_fn=apply_fn,
+            loss_fn=loss_fn,
+            label=label,
+            **tol,
+        )
+
+    def _set_velocity(robot, x):
+        return robot.set_dofs_velocity(x)
+
+    _check("J2 set_dofs_velocity → state.pos", _rand_np(dof_shape, seed=30), _set_velocity, pos_loss, tol_default)
+    _check("J2 set_dofs_velocity → state.quat", _rand_np(dof_shape, seed=31), _set_velocity, quat_loss, tol_quat)
+    _check(
+        "J2 control_dofs_force → state.quat",
+        _rand_np(dof_shape, seed=32),
+        lambda r, x: r.control_dofs_force(x),
+        quat_loss,
+        tol_quat,
+    )
+
+
+@pytest.mark.required
+@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+@pytest.mark.parametrize("precision", _PRECISION_PARAMS)
+@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS)
+@pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS)
+def test_diff_fk_spherical(show_viewer, n_envs, precision, substeps):
+    """J6: single spherical (ball) joint, fixed base — 3 angular DOFs / 4 qpos (quaternion).
+
+    Exercises the SPHERICAL branch of `kernel_manual_forward_kinematics_bw`, i.e. verifies that
+    the `qloc → qpos[0:4]` chain rule matches FD.
+    """
+    scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair(
+        "xml/grad/spherical.xml", n_envs=n_envs, substeps=substeps
+    )
+    batch = _batch_size(scene_ana)
+    dof_shape = _input_shape((robot_ana.n_dofs,), n_envs)  # n_dofs = 3
+    tol_default = _TOL[(precision, "default")]
+    tol_quat = _TOL[(precision, "quat")]
+    pos_loss = _loss_state_pos(_target((batch, 3), seed=61))
+    quat_loss = _loss_state_quat(_target((batch, 4), seed=62))
+
+    def _check(label, init_input, apply_fn, loss_fn, tol):
+        _grad_matches_fd(
+            scene_ana,
+            robot_ana,
+            scene_fd,
+            robot_fd,
+            init_input=init_input,
+            apply_fn=apply_fn,
+            loss_fn=loss_fn,
+            label=label,
+            **tol,
+        )
+
+    def _set_velocity(robot, x):
+        return robot.set_dofs_velocity(x)
+
+    _check("J6 set_dofs_velocity → state.pos", _rand_np(dof_shape, seed=70), _set_velocity, pos_loss, tol_default)
+    _check("J6 set_dofs_velocity → state.quat", _rand_np(dof_shape, seed=71), _set_velocity, quat_loss, tol_quat)
+    _check(
+        "J6 control_dofs_force → state.quat",
+        _rand_np(dof_shape, seed=72),
+        lambda r, x: r.control_dofs_force(x),
+        quat_loss,
+        tol_quat,
+    )
+
+
+@pytest.mark.required
+@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+@pytest.mark.parametrize("precision", _PRECISION_PARAMS)
+@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS)
+@pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS)
+def test_diff_fk_prismatic(show_viewer, n_envs, precision, substeps):
+    """J3: single prismatic joint, fixed base.
+
+    There is no rotational DOF, so only position outputs are checked (no quat loss).
+    """
+    scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair(
+        "xml/grad/prismatic.xml", n_envs=n_envs, substeps=substeps
+    )
+    batch = _batch_size(scene_ana)
+    dof_shape = _input_shape((robot_ana.n_dofs,), n_envs)  # n_dofs = 1
+    tol_default = _TOL[(precision, "default")]
+    pos_loss = _loss_state_pos(_target((batch, 3), seed=41))
+
+    def _check(label, init_input, apply_fn):
+        _grad_matches_fd(
+            scene_ana,
+            robot_ana,
+            scene_fd,
+            robot_fd,
+            init_input=init_input,
+            apply_fn=apply_fn,
+            loss_fn=pos_loss,
+            label=label,
+            **tol_default,
+        )
+
+    _check("J3 set_dofs_velocity → state.pos", _rand_np(dof_shape, seed=50), lambda r, x: r.set_dofs_velocity(x))
+
+    # fp64-only: FD-vs-analytical on force-driven position sits at fp32's precision floor
+    # (same reasoning as the fp64 gate on J1's control_dofs_force check).
+    if precision == "64":
+        _check("J3 control_dofs_force → state.pos", _rand_np(dof_shape, seed=51), lambda r, x: r.control_dofs_force(x))
+
+
+@pytest.mark.required
+@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+@pytest.mark.parametrize("precision", _PRECISION_PARAMS)
+@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS)
+@pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS)
+def test_diff_fk_cartpole(show_viewer, n_envs, precision, substeps):
+    """J7: cartpole (prismatic cart + revolute pole). 2 links / 2 DOFs.
+
+    Same chain as J4 (multi-link entity with translation + rotation coupling). The link count of
+    the targets is read from the entity rather than hard-coded, matching the J4 / J8 tests, so
+    the targets keep the right size if the asset changes.
+    """
+    scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair("xml/cartpole.xml", n_envs=n_envs, substeps=substeps)
+    n_dofs = robot_ana.n_dofs  # = 2 (slider + hinge)
+    n_links = robot_ana.n_links  # expected to be 2 (cart + pole)
+    B = _batch_size(scene_ana)
+    tol_default = _TOL[(precision, "default")]
+    tol_quat = _TOL[(precision, "quat")]
+
+    # Targets cover every link: the link losses flatten links_pos / links_quat before subtracting.
+    tgt_links_pos = _target((B, n_links, 3), seed=181)
+    tgt_links_quat = _target((B, n_links, 4), seed=182)
+
+    _grad_matches_fd(
+        scene_ana,
+        robot_ana,
+        scene_fd,
+        robot_fd,
+        init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=190),
+        apply_fn=lambda r, x: r.set_dofs_velocity(x),
+        loss_fn=_loss_links_pos(tgt_links_pos),
+        label="J7 set_dofs_velocity → links_pos",
+        **tol_default,
+    )
+
+    _grad_matches_fd(
+        scene_ana,
+        robot_ana,
+        scene_fd,
+        robot_fd,
+        init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=191),
+        apply_fn=lambda r, x: r.set_dofs_velocity(x),
+        loss_fn=_loss_links_quat(tgt_links_quat),
+        label="J7 set_dofs_velocity → links_quat",
+        **tol_quat,
+    )
+
+    # Unlike J1/J3/J4/J8, this force → position check is not gated to fp64.
+    _grad_matches_fd(
+        scene_ana,
+        robot_ana,
+        scene_fd,
+        robot_fd,
+        init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=192),
+        apply_fn=lambda r, x: r.control_dofs_force(x),
+        loss_fn=_loss_links_pos(tgt_links_pos),
+        label="J7 control_dofs_force → links_pos",
+        **tol_default,
+    )
+
+
+@pytest.mark.required
+@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+@pytest.mark.parametrize("precision", _PRECISION_PARAMS)
+@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS)
+@pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS)
+def test_diff_fk_hopper(show_viewer, n_envs, precision, substeps):
+    """J8: hopper, built collision-free.
+
+    6 DOFs (planar floating base rootx+rootz+rooty, then thigh/leg/foot hinges). Per the inline
+    note below, the entity reports 5 links once the base link is counted alongside
+    torso/thigh/leg/foot.
+    """
+    scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair("xml/hopper.xml", n_envs=n_envs, substeps=substeps)
+    dof_shape = _input_shape((robot_ana.n_dofs,), n_envs)  # n_dofs = 6 (rootx, rootz, rooty, thigh, leg, foot)
+    n_links = robot_ana.n_links  # = 5 (base + torso, thigh, leg, foot)
+    batch = _batch_size(scene_ana)
+    tol_default = _TOL[(precision, "default")]
+    tol_quat = _TOL[(precision, "quat")]
+    # Hopper is the largest topology here (5 links, 6 DOFs). At fp32 the batched (n_envs=4) FD
+    # probe quantizes the small-sensitivity links_pos/links_quat entries to a ~2e-3 step, leaving
+    # a few entries ~6e-3 from the analytical. fp64 (single + batched) pins correctness; widen
+    # only the fp32 atol band so the FD-floor noise on the larger chain doesn't trip the check.
+    if precision == "32":
+        tol_default = {**tol_default, "atol": 8e-3}
+        tol_quat = {**tol_quat, "atol": 8e-3}
+
+    links_pos_loss = _loss_links_pos(_target((batch, n_links, 3), seed=201))
+    links_quat_loss = _loss_links_quat(_target((batch, n_links, 4), seed=202))
+
+    def _check(label, init_input, apply_fn, loss_fn, tol):
+        _grad_matches_fd(
+            scene_ana,
+            robot_ana,
+            scene_fd,
+            robot_fd,
+            init_input=init_input,
+            apply_fn=apply_fn,
+            loss_fn=loss_fn,
+            label=label,
+            **tol,
+        )
+
+    def _set_velocity(robot, x):
+        return robot.set_dofs_velocity(x)
+
+    _check("J8 set_dofs_velocity → links_pos", _rand_np(dof_shape, seed=210), _set_velocity, links_pos_loss, tol_default)
+    _check("J8 set_dofs_velocity → links_quat", _rand_np(dof_shape, seed=211), _set_velocity, links_quat_loss, tol_quat)
+
+    # fp64-only: force-driven position sits at fp32's precision floor (same gate as J1).
+    if precision == "64":
+        _check(
+            "J8 control_dofs_force → links_pos",
+            _rand_np(dof_shape, seed=212),
+            lambda r, x: r.control_dofs_force(x),
+            links_pos_loss,
+            tol_default,
+        )
+
+
+@pytest.mark.required
+@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+@pytest.mark.parametrize("precision", _PRECISION_PARAMS)
+@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS)
+@pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS)
+def test_diff_fk_free_with_revolute(show_viewer, n_envs, precision, substeps):
+    """J4: freejoint root + one revolute child.
+
+    Outputs are the multi-link solver_state.links_pos / links_quat, so the child link's FK is
+    exercised too.
+    """
+    scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair(
+        "xml/grad/free_with_revolute.xml", n_envs=n_envs, substeps=substeps
+    )
+    dof_shape = _input_shape((robot_ana.n_dofs,), n_envs)  # n_dofs: 6 free + 1 hinge = 7
+    n_links = robot_ana.n_links  # 2
+    batch = _batch_size(scene_ana)
+    tol_default = _TOL[(precision, "default")]
+    tol_quat = _TOL[(precision, "quat")]
+    links_pos_loss = _loss_links_pos(_target((batch, n_links, 3), seed=61))
+    links_quat_loss = _loss_links_quat(_target((batch, n_links, 4), seed=62))
+
+    def _check(label, init_input, apply_fn, loss_fn, tol):
+        _grad_matches_fd(
+            scene_ana,
+            robot_ana,
+            scene_fd,
+            robot_fd,
+            init_input=init_input,
+            apply_fn=apply_fn,
+            loss_fn=loss_fn,
+            label=label,
+            **tol,
+        )
+
+    def _set_velocity(robot, x):
+        return robot.set_dofs_velocity(x)
+
+    _check(
+        "J4 set_pos → links_pos",
+        _rand_np(_input_shape((3,), n_envs), seed=70),
+        lambda r, x: r.set_pos(x),
+        links_pos_loss,
+        tol_default,
+    )
+
+    # Nominal orientation: identity quaternion plus a small random offset, re-normalized.
+    quat_shape = _input_shape((4,), n_envs)
+    quat0 = np.broadcast_to(np.array([1.0, 0.0, 0.0, 0.0]), quat_shape).copy()
+    quat0 = quat0 + 0.05 * _rand_np(quat_shape, seed=71)
+    quat0 = quat0 / np.linalg.norm(quat0, axis=-1, keepdims=True)
+    _check("J4 set_quat → links_quat", quat0, lambda r, x: r.set_quat(x), links_quat_loss, tol_quat)
+
+    _check("J4 set_dofs_velocity → links_pos", _rand_np(dof_shape, seed=72), _set_velocity, links_pos_loss, tol_default)
+    _check("J4 set_dofs_velocity → links_quat", _rand_np(dof_shape, seed=73), _set_velocity, links_quat_loss, tol_quat)
+
+    # fp64-only: force-driven position sits at fp32's precision floor (same gate as J1).
+    if precision == "64":
+        _check(
+            "J4 control_dofs_force → links_pos",
+            _rand_np(dof_shape, seed=74),
+            lambda r, x: r.control_dofs_force(x),
+            links_pos_loss,
+            tol_default,
+        )
+
+
+ @pytest.mark.required
+ @pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+ @pytest.mark.parametrize("precision", _PRECISION_PARAMS)
+ @pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS)
+ @pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS)
+ def test_diff_fk_revolute_chain3(show_viewer, n_envs, precision, substeps):
+ """J5: 3-link serial revolute chain, fixed base. Tests deeper FK chain."""
+ # `_make_scene_pair` yields the (scene_ana, robot_ana, scene_fd, robot_fd, _) tuple
+ # that every `_grad_matches_fd` call below consumes.
+ scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair(
+ "xml/grad/revolute_chain3.xml", n_envs=n_envs, substeps=substeps
+ )
+ n_dofs = robot_ana.n_dofs # 3
+ n_links = robot_ana.n_links # 3
+ B = _batch_size(scene_ana)
+ # Tolerance kwargs are looked up per precision; the "quat" entry is used for
+ # the quaternion-loss check and "default" for the position-loss checks.
+ tol_default = _TOL[(precision, "default")]
+ tol_quat = _TOL[(precision, "quat")]
+ tgt_links_pos = _target((B, n_links, 3), seed=81)
+ tgt_links_quat = _target((B, n_links, 4), seed=82)
+
+ # `set_dofs_velocity` input against the per-link position loss.
+ _grad_matches_fd(
+ scene_ana,
+ robot_ana,
+ scene_fd,
+ robot_fd,
+ init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=90),
+ apply_fn=lambda r, x: r.set_dofs_velocity(x),
+ loss_fn=_loss_links_pos(tgt_links_pos),
+ label="J5 set_dofs_velocity → links_pos",
+ **tol_default,
+ )
+
+ # Same setter against the per-link quaternion loss.
+ _grad_matches_fd(
+ scene_ana,
+ robot_ana,
+ scene_fd,
+ robot_fd,
+ init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=91),
+ apply_fn=lambda r, x: r.set_dofs_velocity(x),
+ loss_fn=_loss_links_quat(tgt_links_quat),
+ label="J5 set_dofs_velocity → links_quat",
+ **tol_quat,
+ )
+
+ # fp64-only — see J1's control_dofs_force comment.
+ if precision == "64":
+ _grad_matches_fd(
+ scene_ana,
+ robot_ana,
+ scene_fd,
+ robot_fd,
+ init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=92),
+ apply_fn=lambda r, x: r.control_dofs_force(x),
+ loss_fn=_loss_links_pos(tgt_links_pos),
+ label="J5 control_dofs_force → links_pos",
+ **tol_default,
+ )
+
+
+# ---------------------------------------------------------------------------
+# Multi-step gradient verification — exercises cross-step adjoint propagation.
+# ---------------------------------------------------------------------------
+
+
+ # Each entry supplies, in order: (mjcf_str, name, n_dofs, loss_factory, output_shape, seed),
+ # matching the argnames of the parametrization on `test_diff_fk_multistep_control_force`.
+ # `output_shape` is the per-batch shape of the loss target; `seed` also derives the
+ # per-step input seeds in that test.
+ _MULTISTEP_TOPOLOGIES = [
+ pytest.param("xml/grad/free.xml", "J1 freejoint", 6, _loss_state_pos, (3,), 161, id="J1_free"),
+ pytest.param("xml/grad/revolute.xml", "J2 revolute", 1, _loss_state_pos, (3,), 162, id="J2_revolute"),
+ pytest.param("xml/grad/prismatic.xml", "J3 prismatic", 1, _loss_state_pos, (3,), 163, id="J3_prismatic"),
+ pytest.param(
+ "xml/grad/free_with_revolute.xml", "J4 free+revolute", 7, _loss_links_pos, (2, 3), 164, id="J4_free_rev"
+ ),
+ pytest.param("xml/grad/revolute_chain3.xml", "J5 chain3", 3, _loss_links_pos, (3, 3), 165, id="J5_chain3"),
+ pytest.param("xml/grad/spherical.xml", "J6 spherical", 3, _loss_state_pos, (3,), 166, id="J6_spherical"),
+ pytest.param("xml/cartpole.xml", "J7 cartpole", 2, _loss_links_pos, (2, 3), 167, id="J7_cartpole"),
+ pytest.param("xml/hopper.xml", "J8 hopper", 6, _loss_links_pos, (5, 3), 168, id="J8_hopper"),
+ ]
+
+
+ @pytest.mark.required
+ @pytest.mark.precision("64")
+ @pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+ @pytest.mark.parametrize("mjcf_str, name, n_dofs, loss_factory, output_shape, seed", _MULTISTEP_TOPOLOGIES)
+ @pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS)
+ def test_diff_fk_multistep_control_force(
+ show_viewer, mjcf_str, name, n_dofs, loss_factory, output_shape, seed, substeps
+ ):
+ """Per-topology check that `control_dofs_force` applied with a *different*
+ input at each of N=10 steps produces per-step gradients that match FD.
+
+ fp64 + single env only: N=10 with batched + fp32 makes the test slow
+ (~30s per topology) and the fp32 + batched ulps-level noise across
+ multiple steps stacks up enough to require relaxed tolerances that
+ obscure real bugs. Single-step tests already cover fp32 + batched
+ against the same setter.
+ """
+ # n_envs=0 is hard-coded here (not parametrized), per the docstring's single-env note.
+ scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair(mjcf_str, n_envs=0, substeps=substeps)
+ B = _batch_size(scene_ana)
+ target = _target((B, *output_shape), seed=seed)
+
+ # 10 distinct force inputs, one per step.
+ # Seeds are `seed * 100 + t`, so every step of every topology gets its own seed.
+ N = 10
+ init_inputs = [_rand_np((n_dofs,), seed=seed * 100 + t) for t in range(N)]
+
+ # No tolerance kwargs are passed, so `_grad_matches_fd_multistep` uses its own defaults.
+ _grad_matches_fd_multistep(
+ scene_ana,
+ robot_ana,
+ scene_fd,
+ robot_fd,
+ init_inputs=init_inputs,
+ apply_fn=lambda r, x: r.control_dofs_force(x),
+ loss_fn=loss_factory(target),
+ label=f"{name} control_dofs_force × {N} steps",
+ )
+
+
+# ===========================================================================
+# Joint-limit constraint FD (enable_joint_limit=True -> constraints ON)
+# ===========================================================================
+
+
+def _build(mjcf_path: str, *, requires_grad: bool, enable_joint_limit: bool):
+ scene = gs.Scene(
+ sim_options=gs.options.SimOptions(
+ dt=1.0 / 60.0,
+ substeps=4,
+ gravity=(0.0, 0.0, 0.0),
+ requires_grad=requires_grad,
+ ),
+ rigid_options=gs.options.RigidOptions(
+ enable_collision=False,
+ enable_self_collision=False,
+ enable_joint_limit=enable_joint_limit,
+ disable_constraint=not enable_joint_limit,
+ use_hibernation=False,
+ use_contact_island=False,
+ ),
+ show_viewer=False,
+ )
+ robot = scene.add_entity(gs.morphs.MJCF(file=mjcf_path))
+ scene.build(n_envs=0)
+ return scene, robot
+
+
+def _rigid_state(scene):
+ return scene.get_state().solvers_state[scene.solvers.index(scene.rigid_solver)]
+
+
+ @pytest.mark.required
+ @pytest.mark.precision("64")
+ @pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+ def test_diff_joint_limit_forward_enforcement(show_viewer):
+ """With `enable_joint_limit=True` and slider range=[-4,4], pushing the
+ cart at v=100 for 60 steps must keep |x| bounded; with the limit off
+ the cart drifts past 90 m (control case)."""
+ mjcf_path = "xml/grad/slider_limit.xml"
+
+ # Control: limit OFF
+ # Forward-only check, so both scenes are built with requires_grad=False.
+ scene, robot = _build(mjcf_path, requires_grad=False, enable_joint_limit=False)
+ scene.reset()
+ robot.set_dofs_velocity(gs.tensor([100.0], dtype=gs.tc_float))
+ for _ in range(60):
+ scene.step()
+ # qpos[0, 0] — presumably (env 0, dof 0), i.e. the slider position; TODO confirm layout.
+ x_off = float(_rigid_state(scene).qpos[0, 0].detach())
+ # The assertion threshold (50) is looser than the ~90 m drift quoted in the docstring.
+ assert abs(x_off) > 50.0, f"control (limit OFF) cart should drift past 50m, got x={x_off}"
+
+ # Limit ON — should stay bounded.
+ scene, robot = _build(mjcf_path, requires_grad=False, enable_joint_limit=True)
+ scene.reset()
+ robot.set_dofs_velocity(gs.tensor([100.0], dtype=gs.tc_float))
+ for _ in range(60):
+ scene.step()
+ x_on = float(_rigid_state(scene).qpos[0, 0].detach())
+ assert abs(x_on) <= 4.5, f"limit ON should keep |x| <= 4.5 (small margin for soft constraint), got x={x_on}"
+
+
+ @pytest.mark.required
+ @pytest.mark.precision("64")
+ @pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+ @pytest.mark.parametrize("init_vel", [0.5, 5.0])
+ def test_diff_joint_limit_backward_finite_no_limit_hit(show_viewer, init_vel):
+ """When the rollout stays well inside the joint range, the joint-limit
+ branch should *not* activate (`pos_delta >= 0`), and the gradient should
+ match the no-limit baseline almost byte-exactly. This pins the
+ "limit-on but inactive" path through the constraint solver."""
+ mjcf_path = "xml/grad/slider_limit.xml"
+ N_STEPS = 1 # short — cart doesn't reach limit
+
+ # Run the identical rollout with the joint limit disabled and then enabled,
+ # recording d(loss)/d(init velocity) for each, keyed by the `limit` flag.
+ grads = {}
+ for limit in (False, True):
+ scene, robot = _build(mjcf_path, requires_grad=True, enable_joint_limit=limit)
+ scene.reset()
+ v = gs.tensor([init_vel], dtype=gs.tc_float, requires_grad=True)
+ robot.set_dofs_velocity(v)
+ for _ in range(N_STEPS):
+ scene.step()
+ # Loss is the square of qpos[0, 0] after the rollout.
+ loss = (_rigid_state(scene).qpos[0, 0]) ** 2
+ loss.backward()
+ assert v.grad is not None, f"limit={limit}: v.grad is None"
+ g = float(v.grad[0])
+ assert math.isfinite(g), f"limit={limit}: gradient is not finite ({g})"
+ grads[limit] = g
+
+ # Limit-inactive case should match the no-limit baseline tightly — the
+ # constraint branch only runs `n_constraints += 0`, so the autograd tape
+ # should be identical up to floating-point.
+ assert_allclose(grads[True], grads[False], rtol=1e-6, atol=1e-9)
+
+
+ @pytest.mark.required
+ @pytest.mark.precision("64")
+ @pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+ def test_diff_joint_limit_backward_fd_one_step(show_viewer):
+ """FD vs analytical gradient when the cart starts well inside the limit
+ and takes a single step — verifies the constraint-solver-inclusive
+ forward+backward chain still satisfies central FD."""
+ mjcf_path = "xml/grad/slider_limit.xml"
+ init_vel = 2.0
+ eps = 1e-5
+
+ # Analytical
+ # One step from `init_vel`; the gradient of qpos[0, 0]**2 w.r.t. `v` is read from v.grad.
+ scene_ana, robot_ana = _build(mjcf_path, requires_grad=True, enable_joint_limit=True)
+ scene_ana.reset()
+ v = gs.tensor([init_vel], dtype=gs.tc_float, requires_grad=True)
+ robot_ana.set_dofs_velocity(v)
+ scene_ana.step()
+ loss = (_rigid_state(scene_ana).qpos[0, 0]) ** 2
+ loss.backward()
+ ana = float(v.grad[0])
+
+ # FD
+ # A separate scene built with requires_grad=False is reset and re-run for each sample.
+ scene_fd, robot_fd = _build(mjcf_path, requires_grad=False, enable_joint_limit=True)
+
+ def loss_at(val: float) -> float:
+ scene_fd.reset()
+ robot_fd.set_dofs_velocity(gs.tensor([val], dtype=gs.tc_float))
+ scene_fd.step()
+ return float((_rigid_state(scene_fd).qpos[0, 0]) ** 2)
+
+ # Central difference around `init_vel` with half-width `eps`.
+ fd = (loss_at(init_vel + eps) - loss_at(init_vel - eps)) / (2 * eps)
+
+ assert_allclose(ana, fd, rtol=1e-3, atol=1e-6)
+
+
+ # (init_vel, n_steps) cases where the cart actually crosses |x|=4 during the
+ # rollout. Each case engages the constraint solver during the integration —
+ # they cover the `M^{-1} J^T λ` correction path that the unconstrained
+ # `kernel_manual_compute_qacc_bw` could not produce. Resolved 2026-05-25 by
+ # wiring `constraint_solver.backward` + `kernel_manual_add_joint_limit_constraints_bw`
+ # into `substep_pre_coupling_grad`.
+ # Each tuple is unpacked by the "init_vel,n_steps" parametrization of
+ # `test_diff_joint_limit_backward_fd_active`; larger velocities are paired with
+ # fewer steps.
+ _FD_ACTIVE_CASES = [
+ (500.0, 1),
+ (200.0, 2),
+ (100.0, 5),
+ (50.0, 10),
+ ]
+
+
+ @pytest.mark.required
+ @pytest.mark.precision("64")
+ @pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+ @pytest.mark.parametrize("init_vel,n_steps", _FD_ACTIVE_CASES)
+ def test_diff_joint_limit_backward_fd_active(show_viewer, init_vel, n_steps):
+ """FD vs analytical when the cart actually crosses the joint limit during
+ the rollout. Exercises the constrained backward path
+ (`constraint_solver.backward` + `kernel_manual_add_joint_limit_constraints_bw`)
+ on cases where the unconstrained IFT alone would drop the `M^{-1} J^T λ`
+ correction and disagree with FD (often sign-flipped). Snapshot of the
+ expected gradients (FP64 CPU, FD with eps=1e-4) at fix time:
+
+ init_vel n_steps x_final v.grad
+ 500 1 +4.464 +5.51e-2
+ 200 2 +4.203 +9.46e-2
+ 100 5 +3.892 -1.60e-1
+ 50 10 +3.606 -1.16e+0
+ """
+ mjcf_path = "xml/grad/slider_limit.xml"
+ eps = 1e-4
+
+ # Analytical
+ scene_ana, robot_ana = _build(mjcf_path, requires_grad=True, enable_joint_limit=True)
+ scene_ana.reset()
+ v = gs.tensor([init_vel], dtype=gs.tc_float, requires_grad=True)
+ robot_ana.set_dofs_velocity(v)
+ for _ in range(n_steps):
+ scene_ana.step()
+ # Detached copy of the final position, used only for the setup check below.
+ x_final = float(_rigid_state(scene_ana).qpos[0, 0].detach())
+ # Setup sanity: the cart must have entered the limit band, otherwise this
+ # case wouldn't actually exercise the constraint correction path.
+ assert abs(x_final) > 3.5, (
+ f"setup error: init_vel={init_vel}, n_steps={n_steps} did not bring "
+ f"the cart near the limit (x_final={x_final}); pick a larger v0 or "
+ f"more steps."
+ )
+ # The loss re-reads the (non-detached) state so the gradient can reach `v`.
+ loss = (_rigid_state(scene_ana).qpos[0, 0]) ** 2
+ loss.backward()
+ ana = float(v.grad[0])
+
+ # FD
+ # A separate scene built with requires_grad=False is reset and re-run for each sample.
+ scene_fd, robot_fd = _build(mjcf_path, requires_grad=False, enable_joint_limit=True)
+
+ def loss_at(val: float) -> float:
+ scene_fd.reset()
+ robot_fd.set_dofs_velocity(gs.tensor([val], dtype=gs.tc_float))
+ for _ in range(n_steps):
+ scene_fd.step()
+ return float((_rigid_state(scene_fd).qpos[0, 0]) ** 2)
+
+ # Central difference around `init_vel` with half-width `eps`.
+ fd = (loss_at(init_vel + eps) - loss_at(init_vel - eps)) / (2 * eps)
+
+ assert_allclose(ana, fd, rtol=1e-3, atol=1e-6)
+
+
+ # Per-step force horizons that drive the cart into the slider limit through
+ # `control_dofs_force`. Constant +500 N over `n_steps` accelerates the
+ # unit-mass cart past |x|=4 within ~10 substep-groups at dt=1/60, substeps=4
+ # (default solref); shorter horizons leave the cart inside the band and
+ # don't exercise the constraint backward, so we restrict to multi-step
+ # active cases. n_steps=10 probes whether the per-step `force.grad` for
+ # early-horizon steps leaks a wrong gradient when the constrained backward
+ # chain (`constraint_solver.backward` + manual joint-limit BW +
+ # `fwd_dynamics_without_qacc.grad` accumulation) runs across many substeps.
+ # Each entry is an `n_steps` value for the parametrization of
+ # `test_diff_joint_limit_backward_fd_per_step_force`.
+ _FD_FORCE_CASES = [10]
+
+
+ @pytest.mark.required
+ @pytest.mark.precision("64")
+ @pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+ @pytest.mark.parametrize("n_steps", _FD_FORCE_CASES)
+ def test_diff_joint_limit_backward_fd_per_step_force(show_viewer, n_steps):
+ """Central-FD vs analytical gradient on a per-step `control_dofs_force`
+ time series that drives the cart into the slider limit during the
+ rollout. Multi-step variant of `test_diff_joint_limit_backward_fd_active`.
+ """
+ mjcf_path = "xml/grad/slider_limit.xml"
+ eps = 1e-2
+ force_value = 500.0
+ # Baseline force series: one row per step, one column for the single DOF.
+ init_force = np.full((n_steps, 1), force_value, dtype=np.float64)
+
+ # Analytical
+ scene_ana, robot_ana = _build(mjcf_path, requires_grad=True, enable_joint_limit=True)
+ scene_ana.reset()
+ # One independent leaf tensor per step, so each step gets its own `.grad`.
+ forces = [gs.tensor(init_force[t], dtype=gs.tc_float, requires_grad=True) for t in range(n_steps)]
+ for t in range(n_steps):
+ robot_ana.control_dofs_force(forces[t])
+ scene_ana.step()
+ x_final = float(_rigid_state(scene_ana).qpos[0, 0].detach())
+ # Setup sanity: the cart must have entered the limit band, otherwise this
+ # case wouldn't actually exercise the multi-step constraint backward.
+ assert abs(x_final) > 3.5, (
+ f"setup error: n_steps={n_steps} at force={force_value} did not bring "
+ f"the cart near the limit (x_final={x_final}); pick a larger force or "
+ f"more steps."
+ )
+ loss = (_rigid_state(scene_ana).qpos[0, 0]) ** 2
+ loss.backward()
+ for t, f in enumerate(forces):
+ assert f.grad is not None, f"step {t}: force.grad is None"
+ ana = np.array([float(f.grad[0]) for f in forces])
+
+ # FD per-step
+ scene_fd, robot_fd = _build(mjcf_path, requires_grad=False, enable_joint_limit=True)
+
+ # Replays the full rollout from reset with the given (n_steps, 1) force series.
+ def loss_at(perturbed: np.ndarray) -> float:
+ scene_fd.reset()
+ for t in range(n_steps):
+ robot_fd.control_dofs_force(gs.tensor(perturbed[t], dtype=gs.tc_float))
+ scene_fd.step()
+ return float((_rigid_state(scene_fd).qpos[0, 0]) ** 2)
+
+ # Central difference: perturb only step t's force by ±eps, leaving the others at baseline.
+ fd = np.zeros(n_steps)
+ for t in range(n_steps):
+ plus = init_force.copy()
+ plus[t, 0] += eps
+ minus = init_force.copy()
+ minus[t, 0] -= eps
+ fd[t] = (loss_at(plus) - loss_at(minus)) / (2 * eps)
+
+ # Per-step comparison so the failure message identifies the offending step.
+ for t in range(n_steps):
+ assert_allclose(
+ ana[t],
+ fd[t],
+ rtol=1e-3,
+ atol=1e-4,
+ err_msg=(
+ f"per-step force.grad mismatch at t={t}/{n_steps} "
+ f"(ana={ana[t]:+.4e}, fd={fd[t]:+.4e}); full ana={ana}, fd={fd}"
+ ),
+ )
+
+
+def _build_cartpole(mjcf_path: str, *, requires_grad: bool):
+ """Build a multi-body cart+pole scene with gravity + slider limit on."""
+ scene = gs.Scene(
+ sim_options=gs.options.SimOptions(
+ dt=1.0 / 60.0,
+ substeps=4,
+ gravity=(0.0, 0.0, -9.81),
+ requires_grad=requires_grad,
+ ),
+ rigid_options=gs.options.RigidOptions(
+ enable_collision=False,
+ enable_self_collision=False,
+ enable_joint_limit=True,
+ disable_constraint=False,
+ use_hibernation=False,
+ use_contact_island=False,
+ ),
+ show_viewer=False,
+ )
+ robot = scene.add_entity(gs.morphs.MJCF(file=mjcf_path))
+ scene.build(n_envs=0)
+ return scene, robot
+
+
+ @pytest.mark.required
+ @pytest.mark.precision("64")
+ @pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+ @pytest.mark.parametrize("n_steps", [15])
+ def test_diff_joint_limit_backward_fd_per_step_force_cartpole(show_viewer, n_steps):
+ """Multi-body variant of `test_diff_joint_limit_backward_fd_per_step_force`.
+
+ Same per-step `control_dofs_force` setup but on the *cart+pole* MJCF
+ (slider [-4, 4] + free-rotating hinge with gravity). cart DOF is
+ actuated at `force_value`, pole DOF takes 0 force. Loss is `cart_x^2`
+ at the terminal step. Both cart_force.grad AND pole_force.grad are
+ compared to central FD — pole_force.grad is *non-zero* because a
+ hinge torque accelerates the pole, the pole's swing exerts reactive
+ horizontal force on the cart via the hinge, which moves cart_x.
+ """
+ mjcf_path = "xml/cartpole.xml"
+ eps = 1e-2
+ # cart+pole effective mass ≈ 11 (cart 1 + pole 10 horizontal-locked at
+ # hanging), so cart force needs to be larger than the cart-only test
+ # to reach the limit within `n_steps`.
+ force_value = 2000.0
+ # Force shape per step: (n_dofs,) = (cart_f, pole_f). pole stays at 0.
+ init_force = np.zeros((n_steps, 2), dtype=np.float64)
+ init_force[:, 0] = force_value
+
+ # Initial state: cart at x=0, pole hanging down at theta=-pi (same as
+ # `CartPoleSwingUpEnv._init_qpos`). Deterministic; same in ana / FD.
+ init_qpos = [0.0, -math.pi]
+
+ # Analytical
+ scene_ana, robot_ana = _build_cartpole(mjcf_path, requires_grad=True)
+ scene_ana.reset()
+ robot_ana.set_dofs_position(gs.tensor(init_qpos, dtype=gs.tc_float))
+ # One independent leaf tensor per step, so each step gets its own `.grad`.
+ forces = [gs.tensor(init_force[t], dtype=gs.tc_float, requires_grad=True) for t in range(n_steps)]
+ for t in range(n_steps):
+ robot_ana.control_dofs_force(forces[t])
+ scene_ana.step()
+ x_final = float(_rigid_state(scene_ana).qpos[0, 0].detach())
+ # Setup sanity: the cart must be near the slider limit, otherwise the
+ # constrained backward path is not exercised.
+ assert abs(x_final) > 3.5, (
+ f"setup error: cart+pole at n_steps={n_steps}, force={force_value} "
+ f"did not bring the cart near the limit (x_final={x_final}); pick a "
+ f"larger force or more steps."
+ )
+ loss = (_rigid_state(scene_ana).qpos[0, 0]) ** 2
+ loss.backward()
+ for t, f in enumerate(forces):
+ assert f.grad is not None, f"step {t}: force.grad is None"
+ # cart-force grad per step (slot 0); slot 1 is the pole-force grad, which is
+ # also compared against FD below (it is expected to be non-zero, see docstring).
+ ana_cart = np.array([float(f.grad[0]) for f in forces])
+ ana_pole = np.array([float(f.grad[1]) for f in forces])
+
+ # FD per-step on both the cart-force slot (0) and the pole-force slot (1).
+ scene_fd, robot_fd = _build_cartpole(mjcf_path, requires_grad=False)
+
+ # Replays the full rollout from reset (including the initial pose) with the
+ # given (n_steps, 2) force series.
+ def loss_at(perturbed: np.ndarray) -> float:
+ scene_fd.reset()
+ robot_fd.set_dofs_position(gs.tensor(init_qpos, dtype=gs.tc_float))
+ for t in range(n_steps):
+ robot_fd.control_dofs_force(gs.tensor(perturbed[t], dtype=gs.tc_float))
+ scene_fd.step()
+ return float((_rigid_state(scene_fd).qpos[0, 0]) ** 2)
+
+ fd_cart = np.zeros(n_steps)
+ fd_pole = np.zeros(n_steps)
+ for t in range(n_steps):
+ # Central difference on the cart-force entry of step t.
+ plus = init_force.copy()
+ plus[t, 0] += eps
+ minus = init_force.copy()
+ minus[t, 0] -= eps
+ fd_cart[t] = (loss_at(plus) - loss_at(minus)) / (2 * eps)
+
+ # Central difference on the pole-force entry of step t.
+ plus = init_force.copy()
+ plus[t, 1] += eps
+ minus = init_force.copy()
+ minus[t, 1] -= eps
+ fd_pole[t] = (loss_at(plus) - loss_at(minus)) / (2 * eps)
+
+ # Cart-force grad — straight chain from action to cart_x.
+ for t in range(n_steps):
+ assert_allclose(
+ ana_cart[t],
+ fd_cart[t],
+ rtol=1e-3,
+ atol=1e-4,
+ err_msg=(
+ f"cart-pole cart_force.grad mismatch at t={t}/{n_steps} "
+ f"(ana={ana_cart[t]:+.4e}, fd={fd_cart[t]:+.4e}); "
+ f"full ana={ana_cart}, fd={fd_cart}"
+ ),
+ )
+ # Pole-force grad — hinge torque chain: pole_force -> pole_angle ->
+ # pole COM horizontal accel -> reactive force on cart via hinge ->
+ # cart_x. Non-zero, must still match FD step-by-step.
+ for t in range(n_steps):
+ assert_allclose(
+ ana_pole[t],
+ fd_pole[t],
+ rtol=1e-3,
+ atol=1e-4,
+ err_msg=(
+ f"cart-pole pole_force.grad mismatch at t={t}/{n_steps} "
+ f"(ana={ana_pole[t]:+.4e}, fd={fd_pole[t]:+.4e}); "
+ f"full ana_pole={ana_pole}, fd_pole={fd_pole}"
+ ),
+ )
+
+
+def _build_hopper(mjcf_path: str, *, requires_grad: bool):
+ """Build the hopper collision-free with joint limits ON and gravity off.
+
+ Collision is off so the joint-limit constraint is the only constraint in
+ play (no foot-ground contact); gravity is off so the base doesn't drift,
+ keeping the rollout focused on driving a leg joint into its range limit.
+ """
+ scene = gs.Scene(
+ sim_options=gs.options.SimOptions(
+ dt=1.0 / 60.0,
+ substeps=4,
+ gravity=(0.0, 0.0, 0.0),
+ requires_grad=requires_grad,
+ ),
+ rigid_options=gs.options.RigidOptions(
+ enable_collision=False,
+ enable_self_collision=False,
+ enable_joint_limit=True,
+ disable_constraint=False,
+ use_hibernation=False,
+ use_contact_island=False,
+ ),
+ show_viewer=False,
+ )
+ robot = scene.add_entity(gs.morphs.MJCF(file=mjcf_path))
+ scene.build(n_envs=0)
+ return scene, robot
+
+
+ @pytest.mark.required
+ @pytest.mark.precision("64")
+ @pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+ @pytest.mark.parametrize("n_steps", [10])
+ def test_diff_joint_limit_backward_fd_per_step_force_hopper(show_viewer, n_steps):
+ """Joint-limit backward FD check on the hopper — combines the multi-joint
+ planar base (slide+slide+hinge on the torso) with an *active* joint-limit
+ constraint, both collision-free.
+
+ A constant torque on the foot joint drives it past its `[-0.785, 0.785]`
+ range during the rollout, engaging the joint-limit inequality constraint.
+ The loss is on the foot LINK world position (so the gradient flows through
+ the full FK chain — exercising `kernel_manual_forward_kinematics_bw`'s multi-joint
+ reverse for the base — as well as the constraint backward). Every DOF's
+ per-step `control_dofs_force.grad` is compared to central FD: the foot DOF
+ is the forced + limited one; the other DOFs are reached only through the
+ articulated coupling, so their FD sensitivity is non-trivial too.
+ """
+ mjcf_path = "xml/hopper.xml"
+ n_dofs = 6 # rootx, rootz, rooty, thigh, leg, foot
+ foot_dof = 5
+ eps = 1e-2
+ force_value = 200.0
+ # Baseline force series: zero everywhere except the foot DOF column.
+ init_force = np.zeros((n_steps, n_dofs), dtype=np.float64)
+ init_force[:, foot_dof] = force_value
+
+ # NOTE(review): this sums the squares of *all* entries of `links_pos`, not only
+ # the foot link mentioned in the docstring — confirm which one is intended.
+ def _links_pos_sq_loss(scene):
+ lp = _rigid_state(scene).links_pos
+ return (lp.reshape(-1) ** 2).sum()
+
+ # Analytical
+ scene_ana, robot_ana = _build_hopper(mjcf_path, requires_grad=True)
+ scene_ana.reset()
+ # One independent leaf tensor per step, so each step gets its own `.grad`.
+ forces = [gs.tensor(init_force[t], dtype=gs.tc_float, requires_grad=True) for t in range(n_steps)]
+ for t in range(n_steps):
+ robot_ana.control_dofs_force(forces[t])
+ scene_ana.step()
+ foot_q = float(_rigid_state(scene_ana).qpos[0, foot_dof].detach())
+ # Setup sanity: the foot must have entered its limit band, else the
+ # constraint backward isn't exercised.
+ assert abs(foot_q) > 0.7, (
+ f"setup error: n_steps={n_steps} at foot force={force_value} did not "
+ f"drive the foot joint near its 0.785 limit (foot_q={foot_q}); pick a "
+ f"larger force or more steps."
+ )
+ loss = _links_pos_sq_loss(scene_ana)
+ loss.backward()
+ for t, f in enumerate(forces):
+ assert f.grad is not None, f"step {t}: force.grad is None"
+ ana = np.array([[float(f.grad[d]) for d in range(n_dofs)] for f in forces]) # (n_steps, n_dofs)
+
+ # FD per-step, per-dof
+ scene_fd, robot_fd = _build_hopper(mjcf_path, requires_grad=False)
+
+ # Replays the full rollout from reset with the given (n_steps, n_dofs) force series.
+ def loss_at(perturbed: np.ndarray) -> float:
+ scene_fd.reset()
+ for t in range(n_steps):
+ robot_fd.control_dofs_force(gs.tensor(perturbed[t], dtype=gs.tc_float))
+ scene_fd.step()
+ return float(_links_pos_sq_loss(scene_fd).detach())
+
+ # Central difference on every (step, dof) entry: 2 * n_steps * n_dofs rollouts in total.
+ fd = np.zeros((n_steps, n_dofs))
+ for t in range(n_steps):
+ for d in range(n_dofs):
+ plus = init_force.copy()
+ plus[t, d] += eps
+ minus = init_force.copy()
+ minus[t, d] -= eps
+ fd[t, d] = (loss_at(plus) - loss_at(minus)) / (2 * eps)
+
+ # Entry-by-entry comparison so a failure reports the offending step and DOF.
+ for t in range(n_steps):
+ for d in range(n_dofs):
+ assert_allclose(
+ ana[t, d],
+ fd[t, d],
+ rtol=1e-3,
+ atol=1e-4,
+ err_msg=(
+ f"hopper force.grad mismatch at t={t}/{n_steps}, dof={d} "
+ f"(ana={ana[t, d]:+.4e}, fd={fd[t, d]:+.4e})\nfull ana=\n{ana}\nfull fd=\n{fd}"
+ ),
+ )
+
+
+# ===========================================================================
+# Collision / diff-GJK contact FD (enable_collision=True -> constraints ON)
+# ===========================================================================
+
+
+def _build_box_box(*, requires_grad: bool):
+ scene = gs.Scene(
+ sim_options=gs.options.SimOptions(
+ dt=0.01,
+ substeps=2,
+ gravity=(0.0, 0.0, -9.81),
+ requires_grad=requires_grad,
+ ),
+ rigid_options=gs.options.RigidOptions(
+ enable_collision=True,
+ enable_self_collision=False,
+ enable_joint_limit=False,
+ disable_constraint=False,
+ use_hibernation=False,
+ use_contact_island=False,
+ box_box_detection=False, # general convex-convex GJK (differentiable) path
+ ),
+ show_viewer=False,
+ )
+ scene.add_entity(gs.morphs.Box(size=(2.0, 2.0, 0.2), pos=(0.0, 0.0, 0.1), fixed=True))
+ box = scene.add_entity(gs.morphs.Box(size=(0.4, 0.4, 0.4), pos=(0.0, 0.0, 0.4)))
+ scene.build(n_envs=0)
+ return scene, box
+
+
+def _build_plane_convex(shape: str, *, requires_grad: bool):
+ scene = gs.Scene(
+ sim_options=gs.options.SimOptions(
+ dt=0.01,
+ substeps=2,
+ gravity=(0.0, 0.0, -9.81),
+ requires_grad=requires_grad,
+ ),
+ rigid_options=gs.options.RigidOptions(
+ enable_collision=True,
+ enable_self_collision=False,
+ enable_joint_limit=False,
+ disable_constraint=False,
+ use_hibernation=False,
+ use_contact_island=False,
+ box_box_detection=False,
+ ),
+ show_viewer=False,
+ )
+ scene.add_entity(gs.morphs.Plane())
+ if shape == "box":
+ obj = scene.add_entity(gs.morphs.Box(size=(0.4, 0.4, 0.4), pos=(0.0, 0.0, 0.3)))
+ elif shape == "sphere":
+ obj = scene.add_entity(gs.morphs.Sphere(radius=0.2, pos=(0.0, 0.0, 0.3)))
+ elif shape == "capsule":
+ obj = scene.add_entity(gs.morphs.MJCF(file="xml/grad/capsule.xml", align=False))
+ else:
+ raise ValueError(shape)
+ scene.build(n_envs=0)
+ return scene, obj
+
+
+def _n_contacts(scene) -> int:
+ return int(scene.rigid_solver.collider._collider_state.n_contacts.to_numpy()[0])
+
+
+def _settle(scene, obj, n_settle: int):
+ zero = gs.tensor([0.0] * 6, dtype=gs.tc_float)
+ for _ in range(n_settle):
+ obj.control_dofs_force(zero)
+ scene.step()
+
+
+ def _run_fd_per_step_force(build_fn, rest_dofs, *, base_force, n_settle, n_steps, fd_dofs, eps, rtol, atol):
+ """Analytical-vs-central-FD driver for a free convex pressed into a fixed
+ collider by a per-step downward force.
+
+ The force is purely downward/centered: a lateral or torque component would
+ tip the body and change the contact manifold (breaking contact-pair
+ preservation), so only the load-bearing z DOF is FD-checked. Its gradient
+ still runs through contact_pos / normal / penetration inside the constraint
+ reverse and the differentiable narrow phase. The FD scene runs in
+ `requires_grad=True` (forward only) so it produces the *same* contact
+ manifold as the analytical scene; a production-mode scene would take a
+ different (non-diff) narrow-phase path with a different contact set.
+ """
+ # Parameters, as used below:
+ # build_fn - called as build_fn(requires_grad=...) and returns (scene, obj).
+ # rest_dofs - DOF positions applied after every reset, before settling.
+ # base_force - force vector broadcast to every step, giving an (n_steps, 6) series.
+ # n_settle - number of zero-force steps run before the gradient window.
+ # n_steps - number of forced steps inside the gradient window.
+ # fd_dofs - indices of the force components that are FD-checked.
+ # eps - half-width of the central difference.
+ # rtol/atol - tolerances forwarded to assert_allclose.
+ init_force = np.broadcast_to(base_force, (n_steps, 6)).copy()
+
+ # --- analytical ---
+ scene_ana, obj_ana = build_fn(requires_grad=True)
+ scene_ana.reset()
+ obj_ana.set_dofs_position(gs.tensor(rest_dofs, dtype=gs.tc_float).sceneless())
+ _settle(scene_ana, obj_ana, n_settle)
+ # Reference contact count; every later step (analytical and FD) must reproduce it.
+ nc = _n_contacts(scene_ana)
+ assert nc > 0, f"setup error: not in contact after settle (n_contacts={nc})"
+
+ # One independent leaf tensor per step, so each step gets its own `.grad`.
+ forces = [gs.tensor(init_force[t], dtype=gs.tc_float, requires_grad=True) for t in range(n_steps)]
+ for t in range(n_steps):
+ obj_ana.control_dofs_force(forces[t])
+ scene_ana.step()
+ assert _n_contacts(scene_ana) == nc, "contact set changed during grad window — FD invalid"
+ # Loss: sum of squares of the first three qpos entries.
+ loss = (_rigid_state(scene_ana).qpos[0, :3] ** 2).sum()
+ scene_ana.backward(loss)
+ ana = np.array([[float(f.grad[d]) for d in range(6)] for f in forces]) # (N, 6)
+
+ # --- central FD, contact set preserved ---
+ scene_fd, obj_fd = build_fn(requires_grad=True)
+
+ # Replays reset + rest pose + settle + forced steps for one perturbed force series.
+ def loss_at(perturbed: np.ndarray) -> float:
+ scene_fd.reset()
+ obj_fd.set_dofs_position(gs.tensor(rest_dofs, dtype=gs.tc_float).sceneless())
+ _settle(scene_fd, obj_fd, n_settle)
+ for t in range(n_steps):
+ obj_fd.control_dofs_force(gs.tensor(perturbed[t], dtype=gs.tc_float))
+ scene_fd.step()
+ assert _n_contacts(scene_fd) == nc, "contact set changed under FD perturbation"
+ return float((_rigid_state(scene_fd).qpos[0, :3] ** 2).sum().detach())
+
+ # Entries outside `fd_dofs` are never computed and stay NaN (visible in the failure dump).
+ fd = np.full((n_steps, 6), np.nan)
+ for t in range(n_steps):
+ for d in fd_dofs:
+ plus = init_force.copy()
+ plus[t, d] += eps
+ minus = init_force.copy()
+ minus[t, d] -= eps
+ fd[t, d] = (loss_at(plus) - loss_at(minus)) / (2 * eps)
+
+ # Contact gradients are small (stiff contact barely moves), so the band is
+ # absolute-dominated; rtol pins the load-bearing z entry.
+ for t in range(n_steps):
+ for d in fd_dofs:
+ assert_allclose(
+ ana[t, d],
+ fd[t, d],
+ rtol=rtol,
+ atol=atol,
+ err_msg=f"contact force.grad mismatch at t={t}/{n_steps}, dof={d}\nana=\n{ana}\nfd=\n{fd}",
+ )
+
+
+ @pytest.mark.required
+ @pytest.mark.precision("64")
+ @pytest.mark.parametrize(
+ "backend",
+ [
+ gs.cpu,
+ pytest.param(
+ gs.gpu,
+ marks=pytest.mark.skip(
+ reason="General convex-convex (box-box) differentiable contact is not supported on GPU: the "
+ "GPU split (multicontact) narrow phase drops contacts under requires_grad=True (n_contacts=0). "
+ "Only plane-convex is GPU-differentiable (see test_diff_contact_fd_plane_convex). "
+ "Revisit when the split path's diff-contact handling is fixed."
+ ),
+ ),
+ ],
+ )
+ def test_diff_contact_fd_per_step_force(show_viewer):
+ # Box rests on the ground top (z=0.2) at center z=0.40; settle to a stable
+ # multi-contact manifold, then a short grad window with a per-step push.
+ # Only force component 2 (the z entry of the 6-vector) is FD-checked; the
+ # base force is -8 on that component and zero elsewhere.
+ _run_fd_per_step_force(
+ _build_box_box,
+ [0.0, 0.0, 0.40, 0.0, 0.0, 0.0], # freejoint 6 DOFs: xyz + rotvec(=identity)
+ base_force=np.array([0.0, 0.0, -8.0, 0.0, 0.0, 0.0], dtype=np.float64),
+ n_settle=12,
+ n_steps=2,
+ fd_dofs=(2,),
+ eps=1e-2,
+ rtol=2e-3,
+ atol=1e-10,
+ )
+
+
+ # rest z so the body's lowest point sits on the plane (z=0): box/sphere half
+ # extent 0.2; capsule radius 0.1 + half_length 0.2 = 0.3 (upright).
+ # Keyed by the `shape` parameter of `test_diff_contact_fd_plane_convex`; each
+ # value is passed as `rest_dofs` to `_run_fd_per_step_force` (6 entries, z at index 2).
+ _PLANE_REST = {
+ "box": [0.0, 0.0, 0.20, 0.0, 0.0, 0.0],
+ "sphere": [0.0, 0.0, 0.20, 0.0, 0.0, 0.0],
+ "capsule": [0.0, 0.0, 0.30, 0.0, 0.0, 0.0],
+ }
+
+
+ @pytest.mark.required
+ @pytest.mark.precision("64")
+ @pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+ @pytest.mark.parametrize("shape", ["box", "sphere", "capsule"])
+ def test_diff_contact_fd_plane_convex(shape, show_viewer):
+ # Plane (fixed) + free convex. The analytic plane contact is reconstructed
+ # differentiably via `func_differentiable_plane_contact` (stored convex
+ # support core + radius), so the same FD chain as box-box applies.
+ # The lambda binds `shape` so the driver can call the builder with only
+ # `requires_grad`; the rest pose comes from `_PLANE_REST[shape]`.
+ _run_fd_per_step_force(
+ lambda *, requires_grad: _build_plane_convex(shape, requires_grad=requires_grad),
+ _PLANE_REST[shape],
+ base_force=np.array([0.0, 0.0, -8.0, 0.0, 0.0, 0.0], dtype=np.float64),
+ n_settle=12,
+ n_steps=2,
+ fd_dofs=(2,),
+ eps=1e-2,
+ rtol=2e-3,
+ atol=1e-10,
+ )
+
+
+# ===========================================================================
+# Low-level contact-detection + constraint-solver backward FD
+# ===========================================================================
+
+
+@pytest.mark.required
+@pytest.mark.precision("64")
+@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+def test_diff_contact():
+ # Compares the geom position / quaternion gradients that `collider.backward` leaves in
+ # `geoms_state.pos.grad` and `geoms_state.quat.grad` against central finite differences
+ # of a scalar contact loss, evaluated along random unit directions.
+
+ # Bound on the mean relative error between analytic and finite-difference directional
+ # derivatives (passed as `atol` against 0.0 at the end of the test).
+ RTOL = 1e-4
+
+ scene = gs.Scene(
+ sim_options=gs.options.SimOptions(
+ dt=0.01,
+ # Turn on differentiable mode
+ requires_grad=True,
+ ),
+ show_viewer=False,
+ )
+
+ # Two equally sized boxes. The second one is offset by 0.8 * box_spacing along z, which
+ # is less than one box edge.
+ box_size = 0.25
+ box_spacing = box_size
+ vec_one = np.array([1.0, 1.0, 1.0])
+ box_pos_offset = (0.0, 0.0, 0.0) + 0.5 * box_size * vec_one
+
+ box0 = scene.add_entity(
+ gs.morphs.Box(size=box_size * vec_one, pos=box_pos_offset),
+ )
+ box1 = scene.add_entity(
+ gs.morphs.Box(size=box_size * vec_one, pos=box_pos_offset + 0.8 * box_spacing * np.array([0, 0, 1])),
+ )
+ scene.build()
+ solver = scene.sim.rigid_solver
+ collider = solver.collider
+
+ # Set up initial configuration
+ # The upper box gets Euler angles of 3 degrees on each axis (converted to radians first).
+ x_ang, y_ang, z_ang = 3.0, 3.0, 3.0
+ box1.set_quat(R_to_quat(gs.euler_to_R([np.deg2rad(x_ang), np.deg2rad(y_ang), np.deg2rad(z_ang)])))
+
+ # Keep copies of the unperturbed poses; every finite-difference sample is built from these.
+ box0_init_pos = box0.get_pos().clone()
+ box1_init_pos = box1.get_pos().clone()
+ box0_init_quat = box0.get_quat().clone()
+ box1_init_quat = box1.get_quat().clone()
+
+ ### Compute the initial loss and compute gradients using differentiable contact detection
+ # Detect contact
+ collider.detection()
+
+ # Get contact outputs and their grads
+ # `requires_grad_()` turns each output into an autograd leaf so the loss gradient with
+ # respect to it can be queried below.
+ contacts = collider.get_contacts(as_tensor=True, to_torch=True, keep_batch_dim=True)
+ normal = contacts["normal"].requires_grad_()
+ position = contacts["position"].requires_grad_()
+ penetration = contacts["penetration"].requires_grad_()
+
+ # Scalar loss mixing all three contact outputs. `retain_graph=True` keeps the autograd
+ # graph alive for the following `grad` calls; the last call does not need it.
+ loss = ((normal * position).sum(dim=-1) * penetration).sum()
+ dL_dnormal = torch.autograd.grad(loss, normal, retain_graph=True)[0]
+ dL_dposition = torch.autograd.grad(loss, position, retain_graph=True)[0]
+ dL_dpenetration = torch.autograd.grad(loss, penetration)[0]
+
+ # Compute analytical gradients of the geoms position and quaternion
+ collider.backward(dL_dposition, dL_dnormal, dL_dpenetration)
+ dL_dpos = qd_to_torch(solver.geoms_state.pos.grad)
+ dL_dquat = qd_to_torch(solver.geoms_state.quat.grad)
+
+ ### Compute directional derivatives along random directions
+ FD_EPS = 1e-5
+ TRIALS = 100
+
+ def compute_dL_error(dL_dx, x_type):
+ # Returns the relative error between the analytic directional derivative
+ # `(rand_dx * dL_dx).sum()` and the central finite difference of the loss, averaged
+ # over TRIALS random directions. `x_type == "pos"` perturbs the box positions; any
+ # other value perturbs the box quaternions.
+ dL_error_rel = 0.0
+
+ # Start from the unperturbed poses; only the quantity selected by `x_type` is
+ # replaced inside the loop.
+ box0_input_pos = box0_init_pos
+ box1_input_pos = box1_init_pos
+ box0_input_quat = box0_init_quat
+ box1_input_quat = box1_init_quat
+
+ for _ in range(TRIALS):
+ # Random direction, normalised along the last axis.
+ rand_dx = torch.randn_like(dL_dx)
+ rand_dx = torch.nn.functional.normalize(rand_dx, dim=-1)
+
+ # Analytic directional derivative.
+ dL = (rand_dx * dL_dx).sum()
+
+ lossPs = []
+ for sign in (1, -1):
+ # Compute query point
+ # NOTE(review): rand_dx[0, 0] / rand_dx[1, 0] presumably select (geom 0, env 0) and
+ # (geom 1, env 0) - confirm against the layout of `geoms_state.*.grad`.
+ if x_type == "pos":
+ box0_input_pos = box0_init_pos + sign * rand_dx[0, 0] * FD_EPS
+ box1_input_pos = box1_init_pos + sign * rand_dx[1, 0] * FD_EPS
+ else:
+ # FIXME: The quaternion should be normalized
+ box0_input_quat = box0_init_quat + sign * rand_dx[0, 0] * FD_EPS
+ box1_input_quat = box1_init_quat + sign * rand_dx[1, 0] * FD_EPS
+
+ # Update box positions
+ box0.set_pos(box0_input_pos)
+ box1.set_pos(box1_input_pos)
+ box0.set_quat(box0_input_quat)
+ box1.set_quat(box1_input_quat)
+
+ # Re-detect contact.
+ # We need to manually reset the contact counter as we are not running the whole sim step.
+ collider._collider_state.n_contacts.fill(0)
+ collider.detection()
+ contacts = collider.get_contacts(as_tensor=True, to_torch=True, keep_batch_dim=True)
+ normal, position, penetration = contacts["normal"], contacts["position"], contacts["penetration"]
+
+ # Compute loss
+ loss = ((normal * position).sum(dim=-1) * penetration).sum()
+ lossPs.append(loss)
+
+ # Central difference: lossPs[0] is the +eps sample, lossPs[1] the -eps sample.
+ dL_fd = (lossPs[0] - lossPs[1]) / (2 * FD_EPS)
+ # Relative error; gs.EPS keeps the denominator away from zero.
+ dL_error_rel += (dL - dL_fd).abs() / max(dL.abs(), dL_fd.abs(), gs.EPS)
+
+ dL_error_rel /= TRIALS
+ return dL_error_rel
+
+ # Position gradients are checked first, then quaternion gradients.
+ dL_dpos_error_rel = compute_dL_error(dL_dpos, "pos")
+ assert_allclose(dL_dpos_error_rel, 0.0, atol=RTOL)
+ dL_dquat_error_rel = compute_dL_error(dL_dquat, "quat")
+ assert_allclose(dL_dquat_error_rel, 0.0, atol=RTOL)
+
+
+# This test runs in 64-bit precision because the finite-difference gradient estimates need a
+# perturbation small enough that 32-bit arithmetic cannot resolve it in a stable way.
+@pytest.mark.required
+@pytest.mark.precision("64")
+@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+def test_diff_solver(monkeypatch):
+ # Checks the constraint solver's backward pass: the gradients of a scalar loss on `qacc`
+ # with respect to mass, jac, aref, efc_D and force (read from `constraint_state.dL_d*`)
+ # are compared against central finite differences along random directions.
+ from genesis.engine.solvers.rigid.constraint.solver import func_solve_init, func_solve_body
+ from genesis.engine.solvers.rigid.rigid_solver import kernel_step_1
+
+ # Bound on the mean relative error per input (passed as `atol` against 0.0 below).
+ RTOL = 1e-4
+
+ scene = gs.Scene(
+ sim_options=gs.options.SimOptions(
+ dt=0.01,
+ requires_grad=True,
+ ),
+ rigid_options=gs.options.RigidOptions(
+ # We use Newton's method because it converges faster than CG, and therefore gives better gradient estimation
+ # when using finite difference method
+ constraint_solver=gs.constraint_solver.Newton,
+ ),
+ show_viewer=False,
+ )
+
+ scene.add_entity(gs.morphs.Plane(pos=(0, 0, 0)))
+ scene.add_entity(gs.morphs.Box(size=(1, 1, 1), pos=(10, 10, 0.49)))
+ franka = scene.add_entity(
+ gs.morphs.MJCF(file="xml/franka_emika_panda/panda.xml"),
+ )
+
+ scene.build()
+ rigid_solver = scene._sim.rigid_solver
+ constraint_solver = rigid_solver.constraint_solver
+
+ # Fixed 9-value joint configuration for the arm.
+ franka.set_qpos([-1.0124, 1.5559, 1.3662, -1.6878, -1.5799, 1.7757, 1.4602, 0.04, 0.04])
+
+ # Monkeypatch the constraint resolve function to avoid overwriting the necessary information for computing gradients.
+ def constraint_solver_resolve():
+ # Replacement for `constraint_solver.resolve`: runs only the init and body kernels.
+ func_solve_init(
+ dofs_info=rigid_solver.dofs_info,
+ dofs_state=rigid_solver.dofs_state,
+ entities_info=rigid_solver.entities_info,
+ constraint_state=constraint_solver.constraint_state,
+ rigid_global_info=rigid_solver._rigid_global_info,
+ static_rigid_sim_config=rigid_solver._static_rigid_sim_config,
+ )
+ func_solve_body(
+ entities_info=rigid_solver.entities_info,
+ dofs_info=rigid_solver.dofs_info,
+ dofs_state=rigid_solver.dofs_state,
+ constraint_state=constraint_solver.constraint_state,
+ rigid_global_info=rigid_solver._rigid_global_info,
+ static_rigid_sim_config=rigid_solver._static_rigid_sim_config,
+ _n_iterations=constraint_solver._n_iterations,
+ )
+
+ monkeypatch.setattr(constraint_solver, "resolve", constraint_solver_resolve)
+
+ # Step once to compute constraint solver's inputs: [mass], [jac], [aref], [efc_D], [force]. We do not call the
+ # entire scene.step() because it will overwrite the necessary information that we need to compute the gradients.
+ kernel_step_1(
+ links_state=rigid_solver.links_state,
+ links_info=rigid_solver.links_info,
+ joints_state=rigid_solver.joints_state,
+ joints_info=rigid_solver.joints_info,
+ dofs_state=rigid_solver.dofs_state,
+ dofs_info=rigid_solver.dofs_info,
+ geoms_state=rigid_solver.geoms_state,
+ geoms_info=rigid_solver.geoms_info,
+ entities_state=rigid_solver.entities_state,
+ entities_info=rigid_solver.entities_info,
+ rigid_global_info=rigid_solver._rigid_global_info,
+ static_rigid_sim_config=rigid_solver._static_rigid_sim_config,
+ contact_island_state=constraint_solver.contact_island.contact_island_state,
+ is_forward_pos_updated=True,
+ is_forward_vel_updated=True,
+ is_backward=False,
+ )
+ # Build the constraint set by hand, then solve once with the patched `resolve`.
+ constraint_solver.add_equality_constraints()
+ rigid_solver.collider.detection()
+ constraint_solver.add_inequality_constraints()
+ constraint_solver.resolve()
+
+ # Loss function to compute gradients using finite difference method
+ def compute_loss(input_mass, input_jac, input_aref, input_efc_D, input_force):
+ # Overwrites the solver inputs with the given arrays, re-solves, and returns the mean
+ # squared difference between the resulting qacc and `target_qacc`. `target_qacc` is
+ # taken from the enclosing scope; it is assigned further down, before the first call.
+ rigid_solver._rigid_global_info.mass_mat.from_numpy(input_mass)
+ constraint_solver.constraint_state.jac.from_numpy(input_jac)
+ constraint_solver.constraint_state.aref.from_numpy(input_aref)
+ constraint_solver.constraint_state.efc_D.from_numpy(input_efc_D)
+ rigid_solver.dofs_state.force.from_numpy(input_force)
+
+ # Recompute acc_smooth from the updated input variables
+ # NOTE(review): indexing `[..., 0]` presumably selects the single env on a trailing
+ # batch axis - confirm against the field layouts.
+ updated_acc_smooth = np.linalg.solve(input_mass[..., 0], input_force[..., 0])
+ rigid_solver.dofs_state.acc_smooth.from_numpy(updated_acc_smooth[..., None])
+ constraint_solver.resolve()
+
+ output_qacc = qd_to_torch(constraint_solver.qacc)
+ return ((output_qacc - target_qacc) ** 2).mean()
+
+ # Copies of the unperturbed solver inputs; every finite-difference sample starts from these.
+ init_input_mass = qd_to_numpy(rigid_solver._rigid_global_info.mass_mat, copy=True)
+ init_input_jac = qd_to_numpy(constraint_solver.constraint_state.jac, copy=True)
+ init_input_aref = qd_to_numpy(constraint_solver.constraint_state.aref, copy=True)
+ init_input_efc_D = qd_to_numpy(constraint_solver.constraint_state.efc_D, copy=True)
+ init_input_force = qd_to_numpy(rigid_solver.dofs_state.force, copy=True)
+
+ # Initial output of the constraint solver
+ # The seed is fixed before the random target is drawn. The target is then scaled by the
+ # mean magnitude of the initial qacc.
+ set_random_seed(0)
+ init_output_qacc = qd_to_torch(constraint_solver.qacc)
+ target_qacc = torch.from_numpy(np.random.randn(*init_output_qacc.shape)).to(device=gs.device)
+ target_qacc = target_qacc * init_output_qacc.abs().mean()
+
+ # Solve the constraint solver and get the output
+ # A copy is taken and marked as an autograd leaf so dL/dqacc can be obtained from torch.
+ output_qacc = qd_to_torch(constraint_solver.qacc, copy=True).requires_grad_(True)
+
+ # Compute loss and gradient of the output
+ loss = ((output_qacc - target_qacc) ** 2).mean()
+ dL_dqacc = tensor_to_array(torch.autograd.grad(loss, output_qacc)[0])
+
+ # Compute gradients of the input variables: [mass], [jac], [aref], [efc_D], [force]
+ constraint_solver.constraint_state.dL_dqacc.from_numpy(dL_dqacc)
+ constraint_solver.backward()
+
+ # Fetch gradients of the input variables
+ dL_dM = qd_to_numpy(constraint_solver.constraint_state.dL_dM)
+ dL_djac = qd_to_numpy(constraint_solver.constraint_state.dL_djac)
+ dL_daref = qd_to_numpy(constraint_solver.constraint_state.dL_daref)
+ dL_defc_D = qd_to_numpy(constraint_solver.constraint_state.dL_defc_D)
+ dL_dforce = qd_to_numpy(constraint_solver.constraint_state.dL_dforce)
+
+ ### Compute directional derivatives along random directions
+ FD_EPS = 1e-3
+ TRIALS = 200
+
+ # Each input is checked on its own: only the input named by `x_type` is perturbed while
+ # the other four keep their initial values.
+ for dL_dx, x_type in (
+ (dL_dforce, "force"),
+ (dL_daref, "aref"),
+ (dL_defc_D, "efc_D"),
+ (dL_djac, "jac"),
+ (dL_dM, "mass"),
+ ):
+ dL_error = 0.0
+ for _ in range(TRIALS):
+ # Random direction, normalised over axis 0 for the vector-like inputs and over axes
+ # (0, 1) for the matrix-like ones.
+ # NOTE(review): builtin `max` on the norm array presumably relies on it holding a
+ # single element (one env) - confirm.
+ rand_dx = np.random.randn(*dL_dx.shape)
+ rand_dx = rand_dx / max(
+ np.linalg.norm(rand_dx, axis=0 if x_type in ("force", "aref", "efc_D") else (0, 1)), gs.EPS
+ )
+ if x_type == "mass":
+ # Make rand_dx symmetric
+ rand_dx = (rand_dx + np.moveaxis(rand_dx, 0, 1)) * 0.5
+
+ # Analytic directional derivative.
+ dL = (rand_dx * dL_dx).sum()
+
+ input_force = init_input_force
+ input_aref = init_input_aref
+ input_efc_D = init_input_efc_D
+ input_jac = init_input_jac
+ input_mass = init_input_mass
+
+ # 1 * eps
+ if x_type == "force":
+ input_force = init_input_force + rand_dx * FD_EPS
+ elif x_type == "aref":
+ input_aref = init_input_aref + rand_dx * FD_EPS
+ elif x_type == "efc_D":
+ input_efc_D = init_input_efc_D + rand_dx * FD_EPS
+ elif x_type == "jac":
+ input_jac = init_input_jac + rand_dx * FD_EPS
+ elif x_type == "mass":
+ input_mass = init_input_mass + rand_dx * FD_EPS
+ lossP1 = compute_loss(input_mass, input_jac, input_aref, input_efc_D, input_force)
+
+ # -1 * eps
+ if x_type == "force":
+ input_force = init_input_force - rand_dx * FD_EPS
+ elif x_type == "aref":
+ input_aref = init_input_aref - rand_dx * FD_EPS
+ elif x_type == "efc_D":
+ input_efc_D = init_input_efc_D - rand_dx * FD_EPS
+ elif x_type == "jac":
+ input_jac = init_input_jac - rand_dx * FD_EPS
+ elif x_type == "mass":
+ input_mass = init_input_mass - rand_dx * FD_EPS
+
+ lossP2 = compute_loss(input_mass, input_jac, input_aref, input_efc_D, input_force)
+ # Central difference between the +eps and -eps samples.
+ dL_fd = (lossP1 - lossP2) / (2 * FD_EPS)
+
+ # Relative error; gs.EPS keeps the denominator away from zero.
+ dL_error += (dL - dL_fd).abs() / max(abs(dL), abs(dL_fd), gs.EPS)
+
+ dL_error /= TRIALS
+ assert_allclose(dL_error, 0.0, atol=RTOL)
diff --git a/tests/test_grad_mpm.py b/tests/test_grad_mpm.py
new file mode 100644
index 0000000000..6535a26967
--- /dev/null
+++ b/tests/test_grad_mpm.py
@@ -0,0 +1,101 @@
+"""Differentiable MPM + Tool coupling — gradient-flow sanity.
+
+Separate from the rigid grad suite (test_grad_fd / test_grad_optim /
+test_grad_utils): this exercises the MPM solver and a Tool (rigid stick)
+pushing an elastic MPM box toward a goal, and checks that `loss.backward()`
+routes non-zero gradients to the per-step control inputs (and zero to the
+unused final step). It is a coupling/backward smoke test, not an FD or
+optimization check.
+"""
+
+import pytest
+import torch
+
+import genesis as gs
+
+
+# Module-level marks: pytest applies every entry of `pytestmark` to all tests in this file.
+# NOTE(review): `debug` is a project-defined marker; presumably `debug(False)` switches
+# debug mode off for these tests - confirm in the test configuration.
+pytestmark = [
+ pytest.mark.debug(False),
+]
+
+
+@pytest.mark.required
+@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+def test_differentiable_push(show_viewer):
+ # Smoke test for gradient flow through the MPM + Tool coupling: a stick with a Tool
+ # material is driven by a per-step velocity next to an elastic MPM box, a loss is
+ # accumulated at two steps, and after `loss.backward()` the per-step velocity inputs
+ # are checked for non-zero (all but the last) or near-zero (the last) gradients.
+
+ # Number of simulated steps, and therefore of per-step velocity inputs.
+ HORIZON = 10
+
+ scene = gs.Scene(
+ sim_options=gs.options.SimOptions(
+ dt=2e-3,
+ substeps=10,
+ requires_grad=True,
+ ),
+ mpm_options=gs.options.MPMOptions(
+ lower_bound=(0.0, -1.0, 0.0),
+ upper_bound=(1.0, 1.0, 0.55),
+ ),
+ viewer_options=gs.options.ViewerOptions(
+ camera_pos=(2.5, -0.15, 2.42),
+ camera_lookat=(0.5, 0.5, 0.1),
+ ),
+ show_viewer=show_viewer,
+ )
+
+ # Fixed ground plane; the returned handle is not used afterwards.
+ plane = scene.add_entity(
+ gs.morphs.URDF(
+ file="urdf/plane/plane.urdf",
+ fixed=True,
+ )
+ )
+ # The pushing tool: a mesh with the Tool material.
+ stick = scene.add_entity(
+ morph=gs.morphs.Mesh(
+ file="meshes/stirrer.obj",
+ scale=0.6,
+ pos=(0.5, 0.5, 0.05),
+ euler=(90.0, 0.0, 0.0),
+ ),
+ material=gs.materials.Tool(
+ friction=8.0,
+ ),
+ )
+ # The pushed object: an elastic MPM box.
+ obj = scene.add_entity(
+ morph=gs.morphs.Box(
+ lower=(0.2, 0.1, 0.05),
+ upper=(0.4, 0.3, 0.15),
+ ),
+ material=gs.materials.MPM.Elastic(
+ rho=500,
+ ),
+ )
+ # Two environments; `init_pos` and each `v_i` below carry one row per environment.
+ scene.build(n_envs=2)
+
+ # Initial conditions are created with `requires_grad=True` before being handed to the setters.
+ init_pos = gs.tensor([[0.3, 0.1, 0.28], [0.3, 0.1, 0.5]], requires_grad=True)
+ stick.set_position(init_pos)
+ pos_obj_init = gs.tensor([0.3, 0.3, 0.1], requires_grad=True)
+ obj.set_position(pos_obj_init)
+ v_obj_init = gs.tensor([0.0, -1.0, 0.0], requires_grad=True)
+ obj.set_velocity(v_obj_init)
+ goal = gs.tensor([0.5, 0.8, 0.05])
+
+ loss = 0.0
+ v_list = []
+ for i in range(HORIZON):
+ # A fresh leaf tensor per step, kept in `v_list` so its `.grad` can be inspected later.
+ v_i = gs.tensor([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]], requires_grad=True)
+ stick.set_velocity(vel=v_i)
+ v_list.append(v_i)
+
+ scene.step()
+
+ # Mid-horizon loss term: squared distance of the active MPM particles to the goal.
+ if i == HORIZON // 2:
+ mpm_particles = scene.get_state().solvers_state[scene.solvers.index(scene.mpm_solver)]
+ loss += torch.pow(mpm_particles.pos[mpm_particles.active == 1] - goal, 2).sum()
+
+ # Second loss term at the second-to-last step, taken from the object's own state.
+ if i == HORIZON - 2:
+ state = obj.get_state()
+ loss += torch.pow(state.pos - goal, 2).sum()
+ loss.backward()
+
+ # TODO: It would be great to compare the gradient to its analytical or numerical value.
+ # Every velocity except the last must have at least one gradient entry above gs.EPS;
+ # the last one must have every entry below gs.EPS.
+ for v_i in v_list[:-1]:
+ assert (v_i.grad.abs() > gs.EPS).any()
+ assert (v_list[-1].grad.abs() < gs.EPS).all()
diff --git a/tests/test_grad_optim.py b/tests/test_grad_optim.py
new file mode 100644
index 0000000000..7ae6059b46
--- /dev/null
+++ b/tests/test_grad_optim.py
@@ -0,0 +1,430 @@
+"""Integration-level optimization convergence for the differentiable rigid solver.
+
+Within the differentiable-rigid test suite, this is the only *end-to-end* check.
+The others verify the gradient is **locally** correct (FD-vs-analytical at a
+point); this one asks whether that gradient is **useful** — does plain Adam,
+driven by the diff-mode backward over a multi-step horizon, actually converge to
+a known answer?
+
+Two optimization targets on the cartpole (contact-free), each recovering a
+final-state target produced by an identical rollout from known inputs:
+ 1. `test_diff_optim_init_vel_cartpole` — Adam on the initial `dofs_velocity`.
+ 2. `test_diff_optim_control_force_cartpole` — Adam on per-step `control_dofs_force`
+Each asserts the per-env loss (a) drops by ≥2 orders of magnitude and (b) ends
+below an absolute threshold — i.e. the backward yields an informative descent
+direction over the horizon, not merely a locally correct gradient.
+
+Parametrized over precision ∈ {fp64, fp32} and n_envs ∈ {0 (single), 4
+(batched, per-env distinct target / init)}. fp32 uses looser thresholds for its
+lower precision floor.
+"""
+
+import os
+import sys
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+
+import genesis as gs
+
+from .utils import assert_allclose
+
+
+# Precision axis of the test matrix: each value carries the matching `precision` mark and
+# a readable test id.
+_PRECISION_PARAMS = [
+ pytest.param("64", marks=pytest.mark.precision("64"), id="fp64"),
+ pytest.param("32", marks=pytest.mark.precision("32"), id="fp32"),
+]
+
+# Batch-size axis: n_envs=0 is the single-scene case, n_envs=4 the batched case.
+_N_ENVS_PARAMS = [
+ pytest.param(0, id="single"),
+ pytest.param(4, id="batched"),
+]
+
+
+def _build_scene(*, requires_grad: bool, n_envs: int = 0, substeps: int = 1):
+    """Build the cartpole scene and return ``(scene, robot)``.
+
+    Collision, self-collision, joint limits, hibernation and contact islands
+    are switched off and constraints are disabled; the viewer is never shown.
+    ``requires_grad`` and ``substeps`` go to the sim options, ``n_envs`` to
+    ``scene.build``.
+    """
+    sim_opts = gs.options.SimOptions(
+        dt=0.01,
+        substeps=substeps,
+        gravity=(0.0, 0.0, -9.81),
+        requires_grad=requires_grad,
+    )
+    rigid_opts = gs.options.RigidOptions(
+        enable_collision=False,
+        enable_self_collision=False,
+        enable_joint_limit=False,
+        disable_constraint=True,
+        use_hibernation=False,
+        use_contact_island=False,
+    )
+    cartpole_scene = gs.Scene(sim_options=sim_opts, rigid_options=rigid_opts, show_viewer=False)
+    cartpole = cartpole_scene.add_entity(gs.morphs.MJCF(file="xml/cartpole.xml"))
+    cartpole_scene.build(n_envs=n_envs)
+    return cartpole_scene, cartpole
+
+
+def _rigid_state(scene):
+    """Return the rigid solver's entry from the scene's current state snapshot."""
+    per_solver_states = scene.get_state().solvers_state
+    rigid_index = scene.solvers.index(scene.rigid_solver)
+    return per_solver_states[rigid_index]
+
+
+def _rollout(scene, robot, init_vel, n_steps):
+    """Reset the scene, set the initial DOF velocity, and simulate.
+
+    `init_vel` is applied through `set_dofs_velocity` (the @tracked setter)
+    right after a fresh `scene.reset()`. The scene is then stepped `n_steps`
+    times and the post-rollout ``(qpos, dofs_vel)`` tensors are read from the
+    rigid solver state.
+    """
+    scene.reset()
+    robot.set_dofs_velocity(init_vel)
+
+    remaining = n_steps
+    while remaining > 0:
+        scene.step()
+        remaining -= 1
+
+    final_state = _rigid_state(scene)
+    return final_state.qpos, final_state.dofs_vel
+
+
+def _input_shape(n_dofs: int, n_envs: int):
+    """Return the per-step input shape: ``(n_dofs,)`` when ``n_envs == 0``, else ``(n_envs, n_dofs)``."""
+    if n_envs == 0:
+        return (n_dofs,)
+    return (n_envs, n_dofs)
+
+
+@pytest.mark.required
+@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+@pytest.mark.parametrize("precision", _PRECISION_PARAMS)
+@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS)
+def test_diff_optim_init_vel_cartpole(show_viewer, n_envs, precision, backend):
+    """Adam on `init_dofs_velocity` (2 params per env: slider + hinge)
+    recovers the per-env final-state target produced by an identical
+    rollout from a known `target_init_vel`. Verifies the diff-mode
+    backward yields an informative descent direction over an N=32 step
+    horizon, with batch independence when n_envs > 0.
+
+    The per-env loss curve is written through the shared `_save_loss_plot`
+    helper, exactly as the control-force test does.
+    """
+    N_STEPS = 32
+    N_ITER = 200
+    LR = 1e-2
+    N_DOFS = 2  # cartpole: slider + hinge
+    B = n_envs if n_envs > 0 else 1
+
+    # Precision-specific tolerances. fp32 has ~7 significant digits, so a
+    # rollout-trained scalar loss can't get below ~1e-4 even at the
+    # optimum, and the rate of improvement plateaus earlier.
+    if precision == "64":
+        REL_REDUCTION = 1e-2
+        ABS_THRESHOLD = 1e-4
+    else:
+        REL_REDUCTION = 1e-1
+        ABS_THRESHOLD = 1e-2
+
+    rng = np.random.default_rng(seed=11)
+
+    # --- target trajectory (non-differentiable scene). Per-env distinct
+    # target_init_vel when n_envs > 0, so each env converges to its own
+    # answer.
+    target_init_vel_np = rng.normal(size=_input_shape(N_DOFS, n_envs)) * 0.5
+
+    scene_ref, robot_ref = _build_scene(requires_grad=False, n_envs=n_envs)
+    target_init_vel_t = gs.tensor(target_init_vel_np, dtype=gs.tc_float)
+    with torch.no_grad():
+        target_qpos, target_vel = _rollout(scene_ref, robot_ref, target_init_vel_t, N_STEPS)
+    # Detached copies, so the targets stay fixed constants during optimization.
+    target_qpos = target_qpos.detach().clone()
+    target_vel = target_vel.detach().clone()
+
+    # --- differentiable scene to optimize on. Per-env distinct init noise.
+    scene_opt, robot_opt = _build_scene(requires_grad=True, n_envs=n_envs)
+    init_offset = rng.normal(size=_input_shape(N_DOFS, n_envs)) * 0.3
+    init_vel_np = target_init_vel_np + init_offset
+    init_vel = gs.tensor(init_vel_np, dtype=gs.tc_float, requires_grad=True)
+
+    opt = torch.optim.Adam([init_vel], lr=LR)
+    loss_history = []  # list of (B,) per-env loss arrays
+
+    for it in range(N_ITER):
+        opt.zero_grad(set_to_none=False)
+        pred_qpos, pred_vel = _rollout(scene_opt, robot_opt, init_vel, N_STEPS)
+        # Flatten to (B, -1) so single and batched runs share the same per-env reduction.
+        diff_pos = (pred_qpos - target_qpos).reshape(B, -1)
+        diff_vel = (pred_vel - target_vel).reshape(B, -1)
+        loss_per_env = (diff_pos**2).sum(dim=-1) + (diff_vel**2).sum(dim=-1)
+        loss = loss_per_env.sum()
+        loss_history.append(loss_per_env.detach().cpu().numpy().copy())
+        loss.backward()
+        assert init_vel.grad is not None, f"iter {it}: init_vel.grad is None"
+        opt.step()
+
+    history = np.asarray(loss_history)  # (N_ITER, B)
+    initial = history[0]
+    final = history[-1]
+
+    # Save per-env loss curves for visual inspection. The helper honours
+    # GENESIS_DIFF_OPTIM_NO_PLOT / GENESIS_DIFF_OPTIM_PLOT_PATH and silently
+    # skips plotting when matplotlib is unavailable.
+    backend_name = getattr(backend, "name", str(backend))
+    tag = f"{backend_name}_{precision}_{'batched' if n_envs > 0 else 'single'}"
+    _save_loss_plot(
+        history,
+        title_tag=f"cartpole init_vel optim [{tag}]",
+        plot_name=f"diff_optim_init_vel_cartpole_{tag}",
+    )
+
+    # Per-env assertions: every env must satisfy both criteria.
+    rel_ratios = final / initial
+    worst_rel_env = int(np.argmax(rel_ratios))
+    assert (rel_ratios < REL_REDUCTION).all(), (
+        f"loss reduction insufficient (worst env={worst_rel_env}): "
+        f"initial={initial[worst_rel_env]:.3e}, final={final[worst_rel_env]:.3e}, "
+        f"ratio={rel_ratios[worst_rel_env]:.3e} (>= {REL_REDUCTION:.0e})"
+    )
+    worst_abs_env = int(np.argmax(final))
+    assert (final < ABS_THRESHOLD).all(), (
+        f"final loss above absolute threshold (worst env={worst_abs_env}): "
+        f"{final[worst_abs_env]:.3e} >= {ABS_THRESHOLD:.0e}"
+    )
+
+
+def _save_loss_plot(history: np.ndarray, *, title_tag: str, plot_name: str):
+    """Save a per-env, log-scale loss curve as a PNG.
+
+    `history` has shape (N_ITER, B). Nothing is written when the
+    GENESIS_DIFF_OPTIM_NO_PLOT environment variable is set or when matplotlib
+    cannot be imported. The output path is GENESIS_DIFF_OPTIM_PLOT_PATH if
+    set, otherwise ``<parent of this file's directory>/runs/tmp/<plot_name>.png``.
+    """
+    plotting_disabled = os.environ.get("GENESIS_DIFF_OPTIM_NO_PLOT")
+    if plotting_disabled:
+        return
+    try:
+        import matplotlib
+
+        # Select the non-interactive backend before pyplot is imported.
+        matplotlib.use("Agg")
+        import matplotlib.pyplot as plt
+    except ImportError:
+        return
+
+    fallback_path = Path(__file__).resolve().parents[1] / "runs" / "tmp" / f"{plot_name}.png"
+    target_path = Path(os.environ.get("GENESIS_DIFF_OPTIM_PLOT_PATH", str(fallback_path)))
+    target_path.parent.mkdir(parents=True, exist_ok=True)
+
+    first_losses, last_losses = history[0], history[-1]
+    n_curves = history.shape[1]
+    palette = plt.get_cmap("tab10")
+
+    fig, ax = plt.subplots(figsize=(7, 4))
+    # One curve per env; colours cycle through the 10-entry palette.
+    for env_idx in range(n_curves):
+        ax.plot(history[:, env_idx], lw=1.2, color=palette(env_idx % 10), label=f"env{env_idx}")
+    ax.set_yscale("log")
+    ax.set_xlabel("iteration")
+    ax.set_ylabel("loss (log scale)")
+    headline = (
+        f"{title_tag}: init={first_losses.max():.2e} → final={last_losses.max():.2e} "
+        f"(worst-env ratio {(last_losses / first_losses).max():.2e})"
+    )
+    ax.set_title(headline)
+    ax.grid(True, which="both", alpha=0.3)
+    if n_curves > 1:
+        ax.legend(loc="upper right", fontsize=8)
+    fig.tight_layout()
+    fig.savefig(str(target_path), dpi=120)
+    plt.close(fig)
+    print(f"\n[diff_optim] loss curve saved to {target_path}")
+
+
+@pytest.mark.required
+@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+@pytest.mark.parametrize("precision", _PRECISION_PARAMS)
+@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS)
+def test_diff_optim_control_force_cartpole(show_viewer, n_envs, precision, backend):
+ """Adam on per-step `control_dofs_force` (N_STEPS × n_dofs params per
+ env) recovers the per-env final-state target."""
+ N_STEPS = 32
+ N_ITER = 200
+ LR = 1e-2
+ N_DOFS = 2
+ # Number of loss rows: one per env, or a single row for the unbatched case.
+ B = n_envs if n_envs > 0 else 1
+
+ # Looser thresholds are used for every precision other than "64".
+ if precision == "64":
+ REL_REDUCTION = 1e-2
+ ABS_THRESHOLD = 1e-4
+ else:
+ REL_REDUCTION = 1e-1
+ ABS_THRESHOLD = 1e-2
+
+ rng = np.random.default_rng(seed=23)
+
+ # One force vector per step (and per env when batched), drawn once for the whole horizon.
+ shape_per_step = _input_shape(N_DOFS, n_envs)
+ target_force_np = rng.normal(size=(N_STEPS,) + shape_per_step) * 0.2
+
+ # --- target trajectory (non-differentiable scene) ---
+ scene_ref, robot_ref = _build_scene(requires_grad=False, n_envs=n_envs)
+ with torch.no_grad():
+ scene_ref.reset()
+ for t in range(N_STEPS):
+ robot_ref.control_dofs_force(gs.tensor(target_force_np[t], dtype=gs.tc_float))
+ scene_ref.step()
+ s = _rigid_state(scene_ref)
+ # Detached copies, so the targets stay fixed constants during optimization.
+ target_qpos = s.qpos.detach().clone()
+ target_vel = s.dofs_vel.detach().clone()
+
+ # --- differentiable scene + learnable per-step force tensors ---
+ scene_opt, robot_opt = _build_scene(requires_grad=True, n_envs=n_envs)
+ # The optimization starts from the target forces plus a random offset.
+ init_offset = rng.normal(size=(N_STEPS,) + shape_per_step) * 0.1
+ init_force_np = target_force_np + init_offset
+ # One leaf tensor per step; all of them are handed to a single Adam optimizer.
+ forces = [gs.tensor(init_force_np[t], dtype=gs.tc_float, requires_grad=True) for t in range(N_STEPS)]
+ optimizer = torch.optim.Adam(forces, lr=LR)
+
+ loss_history = []
+ for it in range(N_ITER):
+ optimizer.zero_grad(set_to_none=False)
+ scene_opt.reset()
+ for t in range(N_STEPS):
+ robot_opt.control_dofs_force(forces[t])
+ scene_opt.step()
+ s = _rigid_state(scene_opt)
+ # Flatten to (B, -1) so single and batched runs share the same per-env reduction.
+ diff_pos = (s.qpos - target_qpos).reshape(B, -1)
+ diff_vel = (s.dofs_vel - target_vel).reshape(B, -1)
+ loss_per_env = (diff_pos**2).sum(dim=-1) + (diff_vel**2).sum(dim=-1)
+ loss = loss_per_env.sum()
+ loss_history.append(loss_per_env.detach().cpu().numpy().copy())
+ loss.backward()
+ # Every per-step force tensor must have a gradient after the backward pass.
+ for t, f in enumerate(forces):
+ assert f.grad is not None, f"iter {it} step {t}: force.grad is None"
+ optimizer.step()
+
+ history = np.asarray(loss_history) # (N_ITER, B)
+ initial = history[0]
+ final = history[-1]
+
+ # The tag identifies the parameter combination in the plot title and file name.
+ backend_name = getattr(backend, "name", str(backend))
+ tag = f"{backend_name}_{precision}_{'batched' if n_envs > 0 else 'single'}"
+ _save_loss_plot(
+ history,
+ title_tag=f"cartpole control_force optim [{tag}]",
+ plot_name=f"diff_optim_control_force_cartpole_{tag}",
+ )
+
+ # Per-env assertions: every env must meet both the relative and the absolute criterion.
+ rel_ratios = final / initial
+ worst_rel_env = int(np.argmax(rel_ratios))
+ assert (rel_ratios < REL_REDUCTION).all(), (
+ f"loss reduction insufficient (worst env={worst_rel_env}): "
+ f"initial={initial[worst_rel_env]:.3e}, final={final[worst_rel_env]:.3e}, "
+ f"ratio={rel_ratios[worst_rel_env]:.3e} (>= {REL_REDUCTION:.0e})"
+ )
+ worst_abs_env = int(np.argmax(final))
+ assert (final < ABS_THRESHOLD).all(), (
+ f"final loss above absolute threshold (worst env={worst_abs_env}): "
+ f"{final[worst_abs_env]:.3e} >= {ABS_THRESHOLD:.0e}"
+ )
+
+
+# ===========================================================================
+# Box pose recovery via Adam (rigid, full scene.step rollout)
+# ===========================================================================
+
+
+@pytest.mark.slow # ~250s
+@pytest.mark.required
+@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+def test_differentiable_rigid(show_viewer):
+ # Optimizes a box's initial position and quaternion with Adam so that its pose after a
+ # `horizon`-step rollout matches `goal_pos` / `goal_quat`, then asserts that the last
+ # computed loss is within 1e-2 of zero.
+ dt = 1e-2
+ horizon = 100
+ substeps = 1
+ goal_pos = gs.tensor([0.7, 1.0, 0.05])
+ goal_quat = gs.tensor([0.3, 0.2, 0.1, 0.9])
+ # Normalise the goal quaternion to unit length.
+ goal_quat = goal_quat / torch.norm(goal_quat, dim=-1, keepdim=True)
+
+ scene = gs.Scene(
+ sim_options=gs.options.SimOptions(
+ dt=dt,
+ substeps=substeps,
+ requires_grad=True,
+ gravity=(0, 0, -1),
+ ),
+ rigid_options=gs.options.RigidOptions(
+ enable_collision=False,
+ enable_self_collision=False,
+ enable_joint_limit=False,
+ disable_constraint=True,
+ use_contact_island=False,
+ use_hibernation=False,
+ ),
+ viewer_options=gs.options.ViewerOptions(
+ camera_pos=(2.5, -0.15, 2.42),
+ camera_lookat=(0.5, 0.5, 0.1),
+ ),
+ show_viewer=show_viewer,
+ )
+
+ box = scene.add_entity(
+ gs.morphs.Box(
+ pos=(0, 0, 0),
+ size=(0.1, 0.1, 0.2),
+ ),
+ surface=gs.surfaces.Default(
+ color=(0.9, 0.0, 0.0, 1.0),
+ ),
+ )
+ # A second box marking the goal pose is only added when the viewer is shown.
+ if show_viewer:
+ target = scene.add_entity(
+ gs.morphs.Box(
+ pos=goal_pos,
+ quat=goal_quat,
+ size=(0.1, 0.1, 0.2),
+ ),
+ surface=gs.surfaces.Default(
+ color=(0.0, 0.9, 0.0, 0.5),
+ ),
+ )
+
+ scene.build()
+
+ num_iter = 200
+ lr = 1e-2
+
+ # Learnable initial pose; both tensors are leaves handed to the optimizer.
+ init_pos = gs.tensor([0.3, 0.1, 0.28], requires_grad=True)
+ init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True)
+ optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr)
+
+ # Cosine schedule over the whole run, decaying the learning rate towards eta_min.
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3)
+
+ for _ in range(num_iter):
+ scene.reset()
+
+ box.set_pos(init_pos)
+ box.set_quat(init_quat)
+
+ loss = 0
+ for _ in range(horizon):
+ scene.step()
+ if show_viewer:
+ target.set_pos(goal_pos)
+ target.set_quat(goal_quat)
+
+ # L1 distance between the box pose and the goal pose (position plus quaternion terms).
+ box_state = box.get_state()
+ box_pos = box_state.pos
+ box_quat = box_state.quat
+ loss = torch.abs(box_pos - goal_pos).sum() + torch.abs(box_quat - goal_quat).sum()
+
+ optimizer.zero_grad()
+ loss.backward() # this lets gradient flow all the way back to tensor input
+ optimizer.step()
+ scheduler.step()
+
+ # Renormalise the learned quaternion in place (outside autograd) after the update.
+ with torch.no_grad():
+ init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True)
+
+ assert_allclose(loss, 0.0, atol=1e-2)
diff --git a/tests/test_grad_utils.py b/tests/test_grad_utils.py
new file mode 100644
index 0000000000..fe7806a06a
--- /dev/null
+++ b/tests/test_grad_utils.py
@@ -0,0 +1,261 @@
+"""Unit tests for the `scene.backward(loss)` API.
+
+Within the differentiable-rigid test suite, this is the *plumbing* layer —
+orthogonal to which physics is active. The other files check that the gradient
+is numerically correct; this one checks that the backward *machinery* (state
+snapshot/restore, gradient-tape clearing, no grad leak across chunked horizons)
+is correct.
+
+`scene.backward(loss)` folds the snapshot → backward → restore dance
+(`scene.get_state()` → `loss.backward()` → `scene.reset(snapshot)`) into a single
+call; the flagship test below exercises it directly.
+
+The flagship test, ``test_horizon_truncation_matches_independent_scenes``,
+builds and runs three scenes, one after another:
+
+ Scene A: single scene, 5-step horizon 1 → ``scene.backward(loss1)``
+ (snapshot + backward + restore in one call) → 5-step horizon 2
+ → ``scene.backward(loss2)``. Yields ``grad1_A`` and ``grad2_A``.
+ Scene B: same as A's horizon 1 only; ``scene.backward`` returns the
+ captured mid-trajectory snapshot. Yields ``grad1_B`` (compared
+ to ``grad1_A``) and that snapshot.
+ Scene C: fresh scene, starts from B's snapshot, runs 5-step horizon 2 →
+ ``scene.backward(loss2)``. Yields ``grad2_C`` (compared to ``grad2_A``).
+
+If `scene.backward(loss)` correctly (a) restores physics state, (b) clears
+the gradient tape, and (c) doesn't leak grad accumulation across horizons,
+then ``grad1_A == grad1_B`` and ``grad2_A == grad2_C`` (asserted up to ``_TOL``).
+"""
+
+import sys
+
+import numpy as np
+import pytest
+import torch
+
+import genesis as gs
+from genesis.utils.misc import qd_to_torch
+
+from .utils import assert_allclose
+
+
+pytestmark = [
+ pytest.mark.debug(False),
+]
+
+
+# Parametrization params (mirrors `test_grad_fd.py`).
+_PRECISION_PARAMS = [
+ pytest.param("64", marks=pytest.mark.precision("64"), id="fp64"),
+ pytest.param("32", marks=pytest.mark.precision("32"), id="fp32"),
+]
+
+_N_ENVS_PARAMS = [
+ pytest.param(0, id="single"),
+ pytest.param(4, id="batched"),
+]
+
+_TOL = {
+ "64": dict(atol=1e-12, rtol=1e-10),
+ "32": dict(atol=1e-5, rtol=1e-4),
+}
+
+
+# J1~J5 joint topologies, loaded from the shared `xml/grad/` MJCF assets.
+_TOPOLOGIES = [
+ pytest.param("free", 6, id="J1_free"),
+ pytest.param("revolute", 1, id="J2_revolute"),
+ pytest.param("prismatic", 1, id="J3_prismatic"),
+ pytest.param("free_with_revolute", 7, id="J4_free_rev"),
+ pytest.param("revolute_chain3", 3, id="J5_chain3"),
+]
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _build_scene(model_name: str, n_envs: int = 0, substeps: int = 1):
+ """Build a diff-rigid scene with the standard "no collision / no constraint" config."""
+ scene = gs.Scene(
+ sim_options=gs.options.SimOptions(
+ dt=0.01,
+ substeps=substeps,
+ gravity=(0.0, 0.0, 0.0),
+ requires_grad=True,
+ ),
+ rigid_options=gs.options.RigidOptions(
+ enable_collision=False,
+ enable_self_collision=False,
+ enable_joint_limit=False,
+ disable_constraint=True,
+ use_hibernation=False,
+ use_contact_island=False,
+ ),
+ show_viewer=False,
+ )
+ robot = scene.add_entity(gs.morphs.MJCF(file=f"xml/grad/{model_name}.xml"))
+ scene.build(n_envs=n_envs)
+ return scene, robot
+
+
+def _make_velocity(n_envs: int, n_dofs: int, seed: int) -> np.ndarray:
+ """Per-env-distinct velocity vector. Single env: shape (n_dofs,). Batched: (n_envs, n_dofs)."""
+ rng = np.random.default_rng(seed)
+ if n_envs == 0:
+ return rng.standard_normal(n_dofs)
+ return rng.standard_normal((n_envs, n_dofs))
+
+
+def _rigid_qpos_loss(scene):
+ """Differentiable scalar loss = sum((qpos)**2). Reads `state.qpos` via
+ `scene.get_state()` so the resulting tensor is a gs.Tensor whose
+ `.backward()` triggers `scene._backward()`."""
+ state = scene.get_state()
+ rigid_state = state.solvers_state[scene.solvers.index(scene.rigid_solver)]
+ return (rigid_state.qpos**2).sum()
+
+
+def _run_segment(scene, robot, v_tensor, n_steps: int):
+ """Apply `set_dofs_velocity(v_tensor)` once, then step `n_steps` times.
+ Returns the resulting (post-step) scalar loss."""
+ robot.set_dofs_velocity(v_tensor)
+ for _ in range(n_steps):
+ scene.step()
+ return _rigid_qpos_loss(scene)
+
+
+def _read_qpos(scene) -> np.ndarray:
+ """Read the simulator's current qpos field (detached)."""
+ solver = scene.rigid_solver
+ return qd_to_torch(solver._rigid_global_info.qpos, copy=True).cpu().numpy()
+
+
+@pytest.mark.required
+@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+@pytest.mark.parametrize("precision_str", _PRECISION_PARAMS)
+@pytest.mark.parametrize("substeps", [1, 4])
+@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS)
+@pytest.mark.parametrize("model_name, n_dofs", _TOPOLOGIES)
+def test_horizon_truncation_matches_independent_scenes(model_name, n_dofs, n_envs, substeps, precision_str):
+ """Two-segment trajectory in Scene A matches the same two
+ segments run in independent Scene B (horizon 1) and Scene C (horizon 2,
+ started from B's mid-trajectory snapshot via `scene.reset(state)`).
+
+ Verifies that `scene.get_state()` + `scene.reset(state)` correctly
+ isolates two consecutive horizons: physics state propagates seamlessly,
+ but the autograd tapes are independent."""
+ tol = _TOL[precision_str]
+ rng_v1 = _make_velocity(n_envs, n_dofs, seed=101)
+ rng_v2 = _make_velocity(n_envs, n_dofs, seed=202)
+ H = 5
+
+ # ----- Scene A: one scene, snapshot+reset between two horizons -----
+ sceneA, robotA = _build_scene(model_name, n_envs=n_envs, substeps=substeps)
+ sceneA.reset()
+ v1A = gs.tensor(rng_v1, dtype=gs.tc_float, requires_grad=True)
+ loss_h1_A = _run_segment(sceneA, robotA, v1A, H)
+ qpos_mid_A = _read_qpos(sceneA)
+ # `scene.backward` snapshots the terminal state, runs the backward unroll,
+ # and restores that state — so horizon 2 continues seamlessly from here.
+ sceneA.backward(loss_h1_A)
+ # backward consumes the adstack / input buffer, so the step & substep
+ # counters reset to 0 (they index that buffer) — unlike the physics state,
+ # which is restored. Horizon 2 below thus records a fresh tape from 0.
+ assert sceneA._t == 0 and sceneA._sim._cur_substep_global == 0
+ grad1_A = v1A.grad.detach().clone().cpu().numpy()
+
+ v2A = gs.tensor(rng_v2, dtype=gs.tc_float, requires_grad=True)
+ loss_h2_A = _run_segment(sceneA, robotA, v2A, H)
+ qpos_end_A = _read_qpos(sceneA)
+ sceneA.backward(loss_h2_A)
+ grad2_A = v2A.grad.detach().clone().cpu().numpy()
+
+ # ----- Scene B: same horizon 1 only -----
+ sceneB, robotB = _build_scene(model_name, n_envs=n_envs, substeps=substeps)
+ sceneB.reset()
+ v1B = gs.tensor(rng_v1, dtype=gs.tc_float, requires_grad=True)
+ loss_h1_B = _run_segment(sceneB, robotB, v1B, H)
+ qpos_mid_B = _read_qpos(sceneB)
+ # `scene.backward` returns the terminal snapshot it captured; Scene C below
+ # loads it into a fresh scene via `reset(snapshot_B)`.
+ snapshot_B = sceneB.backward(loss_h1_B)
+ grad1_B = v1B.grad.detach().clone().cpu().numpy()
+
+ # Sanity: A and B end at the same intermediate state and produce the same loss.
+ assert_allclose(qpos_mid_A, qpos_mid_B, atol=0, rtol=0)
+ assert_allclose(loss_h1_A.detach().cpu().item(), loss_h1_B.detach().cpu().item(), atol=0, rtol=0)
+ # Core assertion: horizon-1 gradient identical.
+ assert_allclose(grad1_A, grad1_B, **tol)
+
+ # ----- Scene C: fresh scene, start from B's mid-trajectory snapshot -----
+ sceneC, robotC = _build_scene(model_name, n_envs=n_envs, substeps=substeps)
+ sceneC.reset(snapshot_B)
+ v2C = gs.tensor(rng_v2, dtype=gs.tc_float, requires_grad=True)
+ loss_h2_C = _run_segment(sceneC, robotC, v2C, H)
+ qpos_end_C = _read_qpos(sceneC)
+ sceneC.backward(loss_h2_C)
+ grad2_C = v2C.grad.detach().clone().cpu().numpy()
+
+ # Sanity: A and C end at the same final state and produce the same loss.
+ assert_allclose(qpos_end_A, qpos_end_C, atol=0, rtol=0)
+ assert_allclose(loss_h2_A.detach().cpu().item(), loss_h2_C.detach().cpu().item(), atol=0, rtol=0)
+ # Core assertion: horizon-2 gradient identical.
+ assert_allclose(grad2_A, grad2_C, **tol)
+
+
+# ===========================================================================
+# sim-state vs solver-state gradient parity
+# ===========================================================================
+
+
+@pytest.mark.required
+@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu])
+def test_diff_sim_vs_solver_state_grad_parity(show_viewer):
+ scene = gs.Scene(
+ sim_options=gs.options.SimOptions(
+ dt=0.01,
+ gravity=(0.0, 0.0, 0.0),
+ requires_grad=True,
+ ),
+ rigid_options=gs.options.RigidOptions(
+ enable_collision=False,
+ ),
+ show_viewer=show_viewer,
+ )
+ robot = scene.add_entity(
+ gs.morphs.Box(
+ size=(0.1, 0.1, 0.1),
+ pos=(0, 0, 0),
+ )
+ )
+ scene.build()
+
+ ctrl = gs.tensor(np.random.randn(robot.n_dofs), dtype=gs.tc_float, requires_grad=True)
+
+ grads = []
+ for use_sim_state in (False, True):
+ scene.reset()
+
+ robot.set_dofs_velocity(ctrl)
+ scene.step()
+
+ if use_sim_state:
+ solver_state = scene.get_state().solvers_state[scene.solvers.index(scene.rigid_solver)]
+ chassis_pos = solver_state.links_pos[:, 0].squeeze()
+ else:
+ chassis_pos = robot.get_state().pos.squeeze()
+
+ loss = torch.linalg.norm(chassis_pos)
+ loss.backward()
+ grad = ctrl.grad.detach().clone()
+ ctrl.grad.zero_()
+
+ # Basic sanity check
+ assert (grad[..., :3].abs() > gs.EPS).all()
+ assert (grad[..., 3:].abs() < gs.EPS).all()
+
+ grads.append(grad)
+
+ assert_allclose(*grads, atol=gs.EPS)