diff --git a/genesis/assets/xml/cartpole.xml b/genesis/assets/xml/cartpole.xml new file mode 100644 index 0000000000..d7aba84f55 --- /dev/null +++ b/genesis/assets/xml/cartpole.xml @@ -0,0 +1,15 @@ + + diff --git a/genesis/assets/xml/grad/capsule.xml b/genesis/assets/xml/grad/capsule.xml new file mode 100644 index 0000000000..21fb2cc712 --- /dev/null +++ b/genesis/assets/xml/grad/capsule.xml @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/genesis/assets/xml/grad/free.xml b/genesis/assets/xml/grad/free.xml new file mode 100644 index 0000000000..b2be573454 --- /dev/null +++ b/genesis/assets/xml/grad/free.xml @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/genesis/assets/xml/grad/free_with_revolute.xml b/genesis/assets/xml/grad/free_with_revolute.xml new file mode 100644 index 0000000000..e867e68abc --- /dev/null +++ b/genesis/assets/xml/grad/free_with_revolute.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + diff --git a/genesis/assets/xml/grad/prismatic.xml b/genesis/assets/xml/grad/prismatic.xml new file mode 100644 index 0000000000..d59d4e76ff --- /dev/null +++ b/genesis/assets/xml/grad/prismatic.xml @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/genesis/assets/xml/grad/revolute.xml b/genesis/assets/xml/grad/revolute.xml new file mode 100644 index 0000000000..8221d398a3 --- /dev/null +++ b/genesis/assets/xml/grad/revolute.xml @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/genesis/assets/xml/grad/revolute_chain3.xml b/genesis/assets/xml/grad/revolute_chain3.xml new file mode 100644 index 0000000000..f4ca0bcfb5 --- /dev/null +++ b/genesis/assets/xml/grad/revolute_chain3.xml @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + + diff --git a/genesis/assets/xml/grad/slider_limit.xml b/genesis/assets/xml/grad/slider_limit.xml new file mode 100644 index 0000000000..1b1a33c142 --- /dev/null +++ b/genesis/assets/xml/grad/slider_limit.xml @@ -0,0 +1,10 @@ + + diff --git a/genesis/assets/xml/grad/spherical.xml b/genesis/assets/xml/grad/spherical.xml new file mode 100644 index 0000000000..2496c44a80 --- /dev/null +++ b/genesis/assets/xml/grad/spherical.xml @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/genesis/assets/xml/hopper.xml b/genesis/assets/xml/hopper.xml new file mode 100644 index 0000000000..63859b2b3f --- /dev/null +++ b/genesis/assets/xml/hopper.xml @@ -0,0 +1,27 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index 0db2c18640..523e28d287 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -109,7 +109,7 @@ def __init__( self._load_model() # Initialize target variables and checkpoint - self._tgt_keys = ("pos", "quat", "qpos", "dofs_velocity") + self._tgt_keys = ("pos", "quat", "qpos", "dofs_velocity", "control_dofs_force") self._tgt = dict() self._tgt_buffer = list() self._ckpt = dict() @@ -1156,6 +1156,8 @@ def process_input(self, in_backward=False): self.set_quat(**data_kwargs) case "set_dofs_velocity": self.set_dofs_velocity(**data_kwargs) + case "control_dofs_force": + self.control_dofs_force(**data_kwargs) case _: gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}") @@ -1188,6 +1190,15 @@ def process_input_grad(self): data_kwargs["dofs_idx_local"], data_kwargs["envs_idx"], ) + + case "control_dofs_force": + force = data_kwargs.pop("force") + if force is not None and force.requires_grad: + force._backward_from_qd( + self.set_dofs_force_grad, + data_kwargs["dofs_idx_local"], + data_kwargs["envs_idx"], + ) case _: gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}") @@ -3559,6 +3570,11 @@ def set_dofs_velocity_grad(self, dofs_idx_local, envs_idx, velocity_grad): dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) self._solver.set_dofs_velocity_grad(dofs_idx, envs_idx, velocity_grad.data) + @gs.assert_built + def set_dofs_force_grad(self, dofs_idx_local, envs_idx, force_grad): + dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) + self._solver.set_dofs_force_grad(dofs_idx, envs_idx, force_grad.data) + # ------------------------------------------------------------------------------------ # ----------------------------- DOF property setters --------------------------------- # ------------------------------------------------------------------------------------ @@ -3592,6 +3608,7 @@ def set_dofs_position(self, position, dofs_idx_local=None, envs_idx=None, *, zer # ------------------------------------------------------------------------------------ @gs.assert_built + @tracked def control_dofs_force(self, force, dofs_idx_local=None, envs_idx=None): """ Control the entity's dofs' motor force. This is used for force/torque control. diff --git a/genesis/engine/scene.py b/genesis/engine/scene.py index 7ff57c47f4..65f55f0f66 100644 --- a/genesis/engine/scene.py +++ b/genesis/engine/scene.py @@ -977,13 +977,16 @@ def reset(self, state: SimState | None = None, envs_idx=None): self._reset(state, envs_idx=envs_idx) self._recorder_manager.reset(envs_idx) - def _reset(self, state: SimState | None = None, *, envs_idx=None): + def _reset(self, state: SimState | None = None, *, envs_idx=None, keep_init: bool = False): if self._is_built: if state is None: state = self._init_state else: assert isinstance(state, SimState), "state must be a SimState object" - self._init_state = state + # `keep_init=True` restores the state without making it the new + # init, so a later bare `reset()` still rewinds to the true init. + if not keep_init: + self._init_state = state self._sim.reset(state, envs_idx) else: self._init_state = self._get_state() @@ -1004,6 +1007,49 @@ def _reset(self, state: SimState | None = None, *, envs_idx=None): def _reset_grad(self): self._backward_ready = True + @gs.assert_built + def backward(self, loss: torch.Tensor, *args, **kwargs): + """Differentiate `loss` and restore the terminal physics state. + + Wraps the snapshot/backward/restore dance that differentiable rollouts + otherwise have to perform by hand. `scene._backward()` rewinds physics + state to step 0 as a side-effect of unrolling the adstack, so the safe + pattern is to snapshot the terminal state *before* backward and restore + it *after*: + + snapshot = scene.get_state() # terminal state + loss.backward() # rewinds physics to step 0 + scene.reset(snapshot) # restore + clear grads + re-arm + + This method does exactly that, so callers can just write + `scene.backward(loss)`. Afterwards the scene sits at the terminal physics + state with grads cleared and forward/backward re-armed — ready to continue + the rollout or to be reset to a fresh init. + + The registered initial state (`reset()` with no args) is left untouched. + + Parameters + ---------- + loss : torch.Tensor + Scalar loss to differentiate. Extra args/kwargs (e.g. `gradient`, + `retain_graph`) are forwarded to `torch.autograd.backward`. + """ + # Snapshot the terminal state before backward rewinds physics to step 0. + snapshot = self.get_state() + # `scene._backward()` re-enters the torch graph from each step's queried + # states (`_backward_from_qd` -> `state.backward(retain_graph=True)`), so + # the graph must survive the initial autograd pass. + kwargs.setdefault("retain_graph", True) + # Functional `torch.autograd.backward` fills torch + queried-state grads + # WITHOUT triggering `gs.Tensor.backward`'s auto `scene._backward()`, so + # we drive the sim unroll explicitly below. + torch.autograd.backward(loss, *args, **kwargs) + self._backward() + # Restore to the terminal snapshot; `keep_init=True` preserves the real + # initial state so a later bare `reset()` still rewinds to it. + self._reset(snapshot, keep_init=True) + return snapshot + def _get_state(self): return self._sim.get_state() diff --git a/genesis/engine/solvers/kinematic_solver.py b/genesis/engine/solvers/kinematic_solver.py index 92ae00be0b..1acf7fb84f 100644 --- a/genesis/engine/solvers/kinematic_solver.py +++ b/genesis/engine/solvers/kinematic_solver.py @@ -45,6 +45,7 @@ kernel_set_dofs_position, kernel_set_dofs_velocity, kernel_set_dofs_velocity_grad, + kernel_set_dofs_force_grad, kernel_set_dofs_zero_velocity, kernel_set_links_pos, kernel_set_links_quat, @@ -951,6 +952,14 @@ def set_dofs_velocity_grad(self, dofs_idx, envs_idx, velocity_grad): velocity_grad_, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config ) + def set_dofs_force_grad(self, dofs_idx, envs_idx, force_grad): + force_grad_, dofs_idx, envs_idx = self._sanitize_io_variables( + force_grad, dofs_idx, self.n_dofs, "dofs_idx", envs_idx, skip_allocation=True + ) + if self.n_envs == 0: + force_grad_ = force_grad_.unsqueeze(0) + kernel_set_dofs_force_grad(force_grad_, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config) + def set_dofs_position(self, position, dofs_idx=None, envs_idx=None): position, dofs_idx, envs_idx = self._sanitize_io_variables( position, dofs_idx, self.n_dofs, "dofs_idx", envs_idx, skip_allocation=True diff --git a/genesis/engine/solvers/rigid/abd/accessor.py b/genesis/engine/solvers/rigid/abd/accessor.py index 672e47ec7c..aab9ae7381 100644 --- a/genesis/engine/solvers/rigid/abd/accessor.py +++ b/genesis/engine/solvers/rigid/abd/accessor.py @@ -787,6 +787,20 @@ def kernel_set_dofs_zero_velocity( dofs_state.vel[dofs_idx[i_d_], envs_idx[i_b_]] = 0.0 +@qd.kernel(fastcache=True) +def kernel_set_dofs_force_grad( + force_grad: qd.types.ndarray(), + dofs_idx: qd.types.ndarray(), + envs_idx: qd.types.ndarray(), + dofs_state: array_class.DofsState, + static_rigid_sim_config: qd.template(), +): + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) + for i_d_, i_b_ in qd.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): + force_grad[i_b_, i_d_] = dofs_state.ctrl_force.grad[dofs_idx[i_d_], envs_idx[i_b_]] + dofs_state.ctrl_force.grad[dofs_idx[i_d_], envs_idx[i_b_]] = 0.0 + + @qd.kernel(fastcache=True) def kernel_set_dofs_position( position: qd.types.ndarray(), diff --git a/genesis/engine/solvers/rigid/abd/diff.py b/genesis/engine/solvers/rigid/abd/diff.py index 39b491011f..a872d65344 100644 --- a/genesis/engine/solvers/rigid/abd/diff.py +++ b/genesis/engine/solvers/rigid/abd/diff.py @@ -19,7 +19,7 @@ import genesis as gs import genesis.utils.geom as gu import genesis.utils.array_class as array_class -from .forward_kinematics import func_update_cartesian_space +from .forward_kinematics import func_update_cartesian_space, func_forward_velocity @qd.func @@ -181,6 +181,16 @@ def kernel_prepare_backward_substep( force_update_fixed_geoms=False, is_backward=True, ) + func_forward_velocity( + entities_info=entities_info, + links_info=links_info, + links_state=links_state, + joints_info=joints_info, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + is_backward=True, + ) # FIXME: Parameter pruning for ndarray is buggy for now and requires match variable and arg names. # Save results of [update_cartesian_space] to adjoint cache @@ -232,17 +242,17 @@ def kernel_begin_backward_substep( ) if not static_rigid_sim_config.enable_mujoco_compatibility: - # FIXME: Parameter pruning for ndarray is buggy for now and requires match variable and arg names. - # Save results of [update_cartesian_space] to adjoint cache + # Restore the cartesian space that was overwritten by + # post-integrate forward replay in the backward substep func_copy_cartesian_space( - dofs_state=dofs_state, - links_state=links_state, - joints_state=joints_state, - geoms_state=geoms_state, - dofs_state_adjoint_cache=dofs_state_adjoint_cache, - links_state_adjoint_cache=links_state_adjoint_cache, - joints_state_adjoint_cache=joints_state_adjoint_cache, - geoms_state_adjoint_cache=geoms_state_adjoint_cache, + dofs_state=dofs_state_adjoint_cache, + links_state=links_state_adjoint_cache, + joints_state=joints_state_adjoint_cache, + geoms_state=geoms_state_adjoint_cache, + dofs_state_adjoint_cache=dofs_state, + links_state_adjoint_cache=links_state, + joints_state_adjoint_cache=joints_state, + geoms_state_adjoint_cache=geoms_state, static_rigid_sim_config=static_rigid_sim_config, ) @@ -347,6 +357,27 @@ def kernel_copy_acc( dofs_state.acc[i_d, i_b] = rigid_adjoint_cache.dofs_acc[f, i_d, i_b] +@qd.kernel(fastcache=True) +def kernel_copy_next_to_curr_no_check( + dofs_state: array_class.DofsState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), +): + # Unguarded copy of `_next` slots to current. Used in the backward substep right before + # the forward replay so the BW kernels see the post-integrate qpos / vel. + n_qs = rigid_global_info.qpos.shape[0] + n_dofs = dofs_state.vel.shape[0] + _B = dofs_state.vel.shape[1] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_q, i_b in qd.ndrange(n_qs, _B): + rigid_global_info.qpos[i_q, i_b] = rigid_global_info.qpos_next[i_q, i_b] + + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in qd.ndrange(n_dofs, _B): + dofs_state.vel[i_d, i_b] = dofs_state.vel_next[i_d, i_b] + + @qd.func def func_integrate_dq_entity( dq, diff --git a/genesis/engine/solvers/rigid/abd/forward_dynamics.py b/genesis/engine/solvers/rigid/abd/forward_dynamics.py index b5d7535554..c2e8b30aa3 100644 --- a/genesis/engine/solvers/rigid/abd/forward_dynamics.py +++ b/genesis/engine/solvers/rigid/abd/forward_dynamics.py @@ -36,10 +36,7 @@ def update_qacc_from_qvel_delta( dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: qd.template(), - is_backward: qd.template(), ): - BW = qd.static(is_backward) - n_dofs = dofs_state.ctrl_mode.shape[0] _B = dofs_state.ctrl_mode.shape[1] @@ -67,10 +64,7 @@ def update_qvel( dofs_state: array_class.DofsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: qd.template(), - is_backward: qd.template(), ): - BW = qd.static(is_backward) - _B = dofs_state.vel.shape[1] n_dofs = dofs_state.vel.shape[0] @@ -124,7 +118,6 @@ def kernel_compute_mass_matrix( dofs_info=dofs_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=False, ) @@ -163,7 +156,6 @@ def func_forward_dynamics( dofs_info=dofs_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) func_torque_and_passive_force( entities_state=entities_state, @@ -484,220 +476,135 @@ def func_factor_mass( dofs_info: array_class.DofsInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: qd.template(), - is_backward: qd.template(), ): - BW = qd.static(is_backward) - - if qd.static(not BW): - n_entities = entities_info.n_links.shape[0] - _B = dofs_state.ctrl_mode.shape[1] - - if qd.static( - not static_rigid_sim_config.enable_tiled_cholesky_mass_matrix or static_rigid_sim_config.backend == gs.cpu - ): - qd.loop_config(name="factor_mass", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_e, i_b in qd.ndrange(n_entities, _B): - if rigid_global_info.mass_mat_mask[i_e, i_b]: - entity_dof_start = entities_info.dof_start[i_e] - entity_dof_end = entities_info.dof_end[i_e] - n_dofs = entities_info.n_dofs[i_e] - - for i_d in range(entity_dof_start, entity_dof_end): - for j_d in range(entity_dof_start, i_d + 1): - rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat[i_d, j_d, i_b] - - if qd.static(implicit_damping): - I_d = [i_d, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat_L[i_d, i_d, i_b] = ( - rigid_global_info.mass_mat_L[i_d, i_d, i_b] - + dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] - ) - if qd.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast): - if dofs_state.ctrl_mode[i_d, i_b] <= gs.CTRL_MODE.VELOCITY: - rigid_global_info.mass_mat_L[i_d, i_d, i_b] = ( - rigid_global_info.mass_mat_L[i_d, i_d, i_b] - - dofs_info.act_bias[I_d][2] * rigid_global_info.substep_dt[None] - ) - - for i_d_ in range(n_dofs): - i_d = entity_dof_end - i_d_ - 1 - D_inv = 1.0 / rigid_global_info.mass_mat_L[i_d, i_d, i_b] - rigid_global_info.mass_mat_D_inv[i_d, i_b] = D_inv - - for j_d_ in range(i_d - entity_dof_start): - j_d = i_d - j_d_ - 1 - a = rigid_global_info.mass_mat_L[i_d, j_d, i_b] * D_inv - for k_d in range(entity_dof_start, j_d + 1): - rigid_global_info.mass_mat_L[j_d, k_d, i_b] -= ( - a * rigid_global_info.mass_mat_L[i_d, k_d, i_b] - ) - rigid_global_info.mass_mat_L[i_d, j_d, i_b] = a - - # FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them. - rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0 - else: - BLOCK_DIM = qd.static(32) - MAX_DOFS_PER_ENTITY = qd.static(static_rigid_sim_config.tiled_n_dofs_per_entity) - WARP_SIZE = qd.static(32) - - qd.loop_config(name="factor_mass", block_dim=BLOCK_DIM) - for i in range(n_entities * _B * BLOCK_DIM): - tid = i % BLOCK_DIM - i_e = (i // BLOCK_DIM) % n_entities - i_b = i // (BLOCK_DIM * n_entities) - if i_b >= _B: - continue - - if rigid_global_info.mass_mat_mask[i_e, i_b]: - entity_dof_start = entities_info.dof_start[i_e] - entity_dof_end = entities_info.dof_end[i_e] - n_dofs = entities_info.n_dofs[i_e] - n_lower_tri = n_dofs * (n_dofs + 1) // 2 - - mass_mat = qd.simt.block.SharedArray((MAX_DOFS_PER_ENTITY, MAX_DOFS_PER_ENTITY + 1), gs.qd_float) - - i_pair = tid - while i_pair < n_lower_tri: - i_d_ = qd.cast((qd.sqrt(8 * i_pair + 1) - 1) // 2, qd.i32) - j_d_ = i_pair - i_d_ * (i_d_ + 1) // 2 - i_d = entity_dof_start + i_d_ - j_d = entity_dof_start + j_d_ - mass_mat[i_d_, j_d_] = rigid_global_info.mass_mat[i_d, j_d, i_b] - i_pair = i_pair + BLOCK_DIM - qd.simt.block.sync() - - if qd.static(implicit_damping): - i_d_ = tid - while i_d_ < n_dofs: - i_d = entity_dof_start + i_d_ - I_d = [i_d, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d - mass_mat[i_d_, i_d_] = ( - mass_mat[i_d_, i_d_] + dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] - ) - if qd.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast): - if dofs_state.ctrl_mode[i_d, i_b] <= gs.CTRL_MODE.VELOCITY: - mass_mat[i_d_, i_d_] = ( - mass_mat[i_d_, i_d_] - - dofs_info.act_bias[I_d][2] * rigid_global_info.substep_dt[None] - ) - i_d_ = i_d_ + BLOCK_DIM - qd.simt.block.sync() - - for j in range(n_dofs): - i_d_ = n_dofs - j - 1 - i_d = entity_dof_end - j - 1 - - D_inv = 1.0 / mass_mat[i_d_, i_d_] - if tid == 0: - rigid_global_info.mass_mat_D_inv[i_d, i_b] = D_inv - # FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them. - rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0 - - j_d_ = i_d_ - 1 - tid - while j_d_ >= 0: - a = mass_mat[i_d_, j_d_] * D_inv - for k_d in range(j_d_ + 1): - mass_mat[j_d_, k_d] = mass_mat[j_d_, k_d] - a * mass_mat[i_d_, k_d] - mass_mat[i_d_, j_d_] = a - j_d_ = j_d_ - BLOCK_DIM - if qd.static(static_rigid_sim_config.backend == gs.cuda): - if i_d_ <= WARP_SIZE: - qd.simt.warp.sync(qd.u32(0xFFFFFFFF)) - else: - qd.simt.block.sync() - else: - qd.simt.block.sync() + n_entities = entities_info.n_links.shape[0] + _B = dofs_state.ctrl_mode.shape[1] - i_pair = tid - n_strict_lower_tri = n_dofs * (n_dofs - 1) // 2 - while i_pair < n_strict_lower_tri: - i_d_ = qd.cast((qd.sqrt(8 * i_pair + 1) + 1) // 2, qd.i32) - j_d_ = i_pair - i_d_ * (i_d_ - 1) // 2 - i_d = entity_dof_start + i_d_ - j_d = entity_dof_start + j_d_ - rigid_global_info.mass_mat_L[i_d, j_d, i_b] = mass_mat[i_d_, j_d_] - i_pair = i_pair + BLOCK_DIM - else: - # Cholesky decomposition that has safe access pattern and robust handling of divide by zero for AD. Even though - # it is logically equivalent to the above block, it shows slightly numerical difference in the result, and thus - # it fails for a unit test ("test_urdf_rope"), while passing all the others. TODO: Investigate if we can fix this - # and only use this block. - - # Assume this is the outermost loop - qd.loop_config(serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)) - for i_e, i_b in qd.ndrange(entities_info.n_links.shape[0], dofs_state.ctrl_mode.shape[1]): + if qd.static( + not static_rigid_sim_config.enable_tiled_cholesky_mass_matrix or static_rigid_sim_config.backend == gs.cpu + ): + qd.loop_config(name="factor_mass", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) + for i_e, i_b in qd.ndrange(n_entities, _B): if rigid_global_info.mass_mat_mask[i_e, i_b]: - EPS = rigid_global_info.EPS[None] - entity_dof_start = entities_info.dof_start[i_e] entity_dof_end = entities_info.dof_end[i_e] n_dofs = entities_info.n_dofs[i_e] - for i_d0 in range(n_dofs): - i_d = entity_dof_start + i_d0 - i_pr = (entity_dof_start + entity_dof_end - 1) - i_d + for i_d in range(entity_dof_start, entity_dof_end): for j_d in range(entity_dof_start, i_d + 1): - j_pr = (entity_dof_start + entity_dof_end - 1) - j_d - rigid_global_info.mass_mat_L_bw[0, i_pr, j_pr, i_b] = rigid_global_info.mass_mat[i_d, j_d, i_b] - rigid_global_info.mass_mat_L_bw[0, j_pr, i_pr, i_b] = rigid_global_info.mass_mat[i_d, j_d, i_b] + rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat[i_d, j_d, i_b] if qd.static(implicit_damping): I_d = [i_d, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d - qd.atomic_add( - rigid_global_info.mass_mat_L_bw[0, i_pr, i_pr, i_b], - (dofs_info.damping[I_d] * rigid_global_info.substep_dt[None]), + rigid_global_info.mass_mat_L[i_d, i_d, i_b] = ( + rigid_global_info.mass_mat_L[i_d, i_d, i_b] + + dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] ) if qd.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast): if dofs_state.ctrl_mode[i_d, i_b] <= gs.CTRL_MODE.VELOCITY: - qd.atomic_add( - rigid_global_info.mass_mat_L_bw[0, i_pr, i_pr, i_b], - -dofs_info.act_bias[I_d][2] * rigid_global_info.substep_dt[None], + rigid_global_info.mass_mat_L[i_d, i_d, i_b] = ( + rigid_global_info.mass_mat_L[i_d, i_d, i_b] + - dofs_info.act_bias[I_d][2] * rigid_global_info.substep_dt[None] ) - # Cholesky-Banachiewicz algorithm (in the perturbed indices), access pattern is safe for autodiff - # https://en.wikipedia.org/wiki/Cholesky_decomposition - for p_i0 in range(n_dofs): - for p_j0 in range(p_i0 + 1): - # j_pr <= i_pr - i_pr = entity_dof_start + p_i0 - j_pr = entity_dof_start + p_j0 - - sum = gs.qd_float(0.0) - for p_k0 in range(p_j0): - # k_pr < j_pr - k_pr = entity_dof_start + p_k0 - sum = sum + ( - rigid_global_info.mass_mat_L_bw[1, i_pr, k_pr, i_b] - * rigid_global_info.mass_mat_L_bw[1, j_pr, k_pr, i_b] + for i_d_ in range(n_dofs): + i_d = entity_dof_end - i_d_ - 1 + D_inv = 1.0 / rigid_global_info.mass_mat_L[i_d, i_d, i_b] + rigid_global_info.mass_mat_D_inv[i_d, i_b] = D_inv + + for j_d_ in range(i_d - entity_dof_start): + j_d = i_d - j_d_ - 1 + a = rigid_global_info.mass_mat_L[i_d, j_d, i_b] * D_inv + for k_d in range(entity_dof_start, j_d + 1): + rigid_global_info.mass_mat_L[j_d, k_d, i_b] -= ( + a * rigid_global_info.mass_mat_L[i_d, k_d, i_b] ) + rigid_global_info.mass_mat_L[i_d, j_d, i_b] = a - a = rigid_global_info.mass_mat_L_bw[0, i_pr, j_pr, i_b] - sum - b = qd.math.clamp( - rigid_global_info.mass_mat_L_bw[1, j_pr, j_pr, i_b], - EPS, - qd.math.inf, + # FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them. + rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0 + else: + BLOCK_DIM = qd.static(32) + MAX_DOFS_PER_ENTITY = qd.static(static_rigid_sim_config.tiled_n_dofs_per_entity) + WARP_SIZE = qd.static(32) + + qd.loop_config(name="factor_mass", block_dim=BLOCK_DIM) + for i in range(n_entities * _B * BLOCK_DIM): + tid = i % BLOCK_DIM + i_e = (i // BLOCK_DIM) % n_entities + i_b = i // (BLOCK_DIM * n_entities) + if i_b >= _B: + continue + + if rigid_global_info.mass_mat_mask[i_e, i_b]: + entity_dof_start = entities_info.dof_start[i_e] + entity_dof_end = entities_info.dof_end[i_e] + n_dofs = entities_info.n_dofs[i_e] + n_lower_tri = n_dofs * (n_dofs + 1) // 2 + + mass_mat = qd.simt.block.SharedArray((MAX_DOFS_PER_ENTITY, MAX_DOFS_PER_ENTITY + 1), gs.qd_float) + + i_pair = tid + while i_pair < n_lower_tri: + i_d_ = qd.cast((qd.sqrt(8 * i_pair + 1) - 1) // 2, qd.i32) + j_d_ = i_pair - i_d_ * (i_d_ + 1) // 2 + i_d = entity_dof_start + i_d_ + j_d = entity_dof_start + j_d_ + mass_mat[i_d_, j_d_] = rigid_global_info.mass_mat[i_d, j_d, i_b] + i_pair = i_pair + BLOCK_DIM + qd.simt.block.sync() + + if qd.static(implicit_damping): + i_d_ = tid + while i_d_ < n_dofs: + i_d = entity_dof_start + i_d_ + I_d = [i_d, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d + mass_mat[i_d_, i_d_] = ( + mass_mat[i_d_, i_d_] + dofs_info.damping[I_d] * rigid_global_info.substep_dt[None] ) - if p_i0 == p_j0: - rigid_global_info.mass_mat_L_bw[1, i_pr, j_pr, i_b] = qd.sqrt( - qd.math.clamp(a, EPS, qd.math.inf) - ) - else: - rigid_global_info.mass_mat_L_bw[1, i_pr, j_pr, i_b] = a / b + if qd.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast): + if dofs_state.ctrl_mode[i_d, i_b] <= gs.CTRL_MODE.VELOCITY: + mass_mat[i_d_, i_d_] = ( + mass_mat[i_d_, i_d_] + - dofs_info.act_bias[I_d][2] * rigid_global_info.substep_dt[None] + ) + i_d_ = i_d_ + BLOCK_DIM + qd.simt.block.sync() - for i_d0 in range(n_dofs): - for i_d1 in range(i_d0 + 1): - i_d = entity_dof_start + i_d0 - j_d = entity_dof_start + i_d1 - i_pr = (entity_dof_start + entity_dof_end - 1) - i_d - j_pr = (entity_dof_start + entity_dof_end - 1) - j_d + for j in range(n_dofs): + i_d_ = n_dofs - j - 1 + i_d = entity_dof_end - j - 1 - a = rigid_global_info.mass_mat_L_bw[1, i_pr, i_pr, i_b] - rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat_L_bw[ - 1, j_pr, i_pr, i_b - ] / qd.math.clamp(a, EPS, qd.math.inf) + D_inv = 1.0 / mass_mat[i_d_, i_d_] + if tid == 0: + rigid_global_info.mass_mat_D_inv[i_d, i_b] = D_inv + # FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them. + rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0 + + j_d_ = i_d_ - 1 - tid + while j_d_ >= 0: + a = mass_mat[i_d_, j_d_] * D_inv + for k_d in range(j_d_ + 1): + mass_mat[j_d_, k_d] = mass_mat[j_d_, k_d] - a * mass_mat[i_d_, k_d] + mass_mat[i_d_, j_d_] = a + j_d_ = j_d_ - BLOCK_DIM + if qd.static(static_rigid_sim_config.backend == gs.cuda): + if i_d_ <= WARP_SIZE: + qd.simt.warp.sync(qd.u32(0xFFFFFFFF)) + else: + qd.simt.block.sync() + else: + qd.simt.block.sync() - if i_d == j_d: - rigid_global_info.mass_mat_D_inv[i_d, i_b] = 1.0 / (qd.math.clamp(a**2, EPS, qd.math.inf)) + i_pair = tid + n_strict_lower_tri = n_dofs * (n_dofs - 1) // 2 + while i_pair < n_strict_lower_tri: + i_d_ = qd.cast((qd.sqrt(8 * i_pair + 1) + 1) // 2, qd.i32) + j_d_ = i_pair - i_d_ * (i_d_ - 1) // 2 + i_d = entity_dof_start + i_d_ + j_d = entity_dof_start + j_d_ + rigid_global_info.mass_mat_L[i_d, j_d, i_b] = mass_mat[i_d_, j_d_] + i_pair = i_pair + BLOCK_DIM @qd.func @@ -706,55 +613,36 @@ def func_solve_mass_entity( i_b: qd.int32, vec: qd.Tensor, out: qd.Tensor, - out_bw: qd.template(), entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: qd.template(), - is_backward: qd.template(), ): - BW = qd.static(is_backward) - if rigid_global_info.mass_mat_mask[i_e, i_b]: entity_dof_start = entities_info.dof_start[i_e] entity_dof_end = entities_info.dof_end[i_e] n_dofs = entities_info.n_dofs[i_e] - # Step 1: Solve w st. L^T @ w = y + # Step 1: Solve w st. L^T @ w = y. Cross-iter same-buffer read of + # `out[j_d]` (j > i) is fine for forward execution — the entries it + # reads were finalized in earlier (larger-i_d) iterations. The + # reverse of this kernel is never auto-generated; the backward is + # handled by `kernel_manual_compute_qacc_bw` via IFT. for i_d_ in range(n_dofs): i_d = entity_dof_end - i_d_ - 1 curr_out = vec[i_d, i_b] - if qd.static(BW): - out_bw[0, i_d, i_b] = vec[i_d, i_b] - for j_d in range(i_d + 1, entity_dof_end): - # Since we read out[j_d, i_b], and j_d > i_d, which means that out[j_d, i_b] is already - # finalized at this point, we don't need to care about AD mutation rule. - if qd.static(BW): - out_bw[0, i_d, i_b] = ( - out_bw[0, i_d, i_b] - rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out_bw[0, j_d, i_b] - ) - else: - curr_out = curr_out - rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out[j_d, i_b] - - if qd.static(not BW): - out[i_d, i_b] = curr_out + curr_out = curr_out - rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out[j_d, i_b] + out[i_d, i_b] = curr_out # Step 2: z = D^{-1} w for i_d in range(entity_dof_start, entity_dof_end): - if qd.static(BW): - out_bw[1, i_d, i_b] = out_bw[0, i_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b] - else: - out[i_d, i_b] = out[i_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b] + out[i_d, i_b] = out[i_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b] # Step 3: Solve x st. L @ x = z for i_d in range(entity_dof_start, entity_dof_end): curr_out = out[i_d, i_b] - if qd.static(BW): - curr_out = out_bw[1, i_d, i_b] - for j_d in range(entity_dof_start, i_d): curr_out = curr_out - rigid_global_info.mass_mat_L[i_d, j_d, i_b] * out[j_d, i_b] - out[i_d, i_b] = curr_out @@ -763,15 +651,10 @@ def func_solve_mass_batch( i_b: qd.int32, vec: qd.Tensor, out: qd.Tensor, - out_bw: qd.template(), entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: qd.template(), - is_backward: qd.template(), ): - BW = qd.static(is_backward) - - # This loop is considered an inner loop qd.loop_config(serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)) for i_0 in ( range(rigid_global_info.n_awake_entities[i_b]) @@ -779,27 +662,20 @@ def func_solve_mass_batch( else range(entities_info.n_links.shape[0]) ): i_e = rigid_global_info.awake_entities[i_0, i_b] if qd.static(static_rigid_sim_config.use_hibernation) else i_0 - func_solve_mass_entity( - i_e, i_b, vec, out, out_bw, entities_info, rigid_global_info, static_rigid_sim_config, is_backward - ) + func_solve_mass_entity(i_e, i_b, vec, out, entities_info, rigid_global_info, static_rigid_sim_config) @qd.func def func_solve_mass( vec: qd.Tensor, out: qd.Tensor, - out_bw: qd.template(), # None in forward mode, real tensor in backward mode entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: qd.template(), - is_backward: qd.template(), ): - # This loop must be the outermost loop to be differentiable qd.loop_config(serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)) for i_e, i_b in qd.ndrange(entities_info.n_links.shape[0], out.shape[1]): - func_solve_mass_entity( - i_e, i_b, vec, out, out_bw, entities_info, rigid_global_info, static_rigid_sim_config, is_backward - ) + func_solve_mass_entity(i_e, i_b, vec, out, entities_info, rigid_global_info, static_rigid_sim_config) @qd.func @@ -1257,11 +1133,9 @@ def func_compute_qacc( func_solve_mass( vec=dofs_state.force, out=dofs_state.acc_smooth, - out_bw=dofs_state.acc_smooth_bw, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) # Assume this is the outermost loop @@ -1371,11 +1245,7 @@ def func_integrate( dofs_state.vel_next[dof_start + 2, i_b], ] ) - # Backward pass requires atomic add - if qd.static(BW): - qd.atomic_add(pos, vel * rigid_global_info.substep_dt[None]) - else: - pos = pos + vel * rigid_global_info.substep_dt[None] + pos += vel * rigid_global_info.substep_dt[None] for j in qd.static(range(3)): rigid_global_info.qpos_next[q_start + j, i_b] = pos[j] if joint_type == gs.JOINT_TYPE.SPHERICAL or joint_type == gs.JOINT_TYPE.FREE: @@ -1427,6 +1297,10 @@ def kernel_forward_dynamics_without_qacc( contact_island_state: array_class.ContactIslandState, is_backward: qd.template(), ): + # Backward-only kernel. `func_factor_mass` is omitted: its reverse is + # unneeded since `kernel_manual_compute_qacc_bw` seeds `mass_mat.grad` + # directly via IFT (skipping the LDLT factor chain). `func_compute_mass_matrix` + # is kept so Quadrants AD auto-reverses `mass_mat -> links_state.{pos,quat}`. func_compute_mass_matrix( implicit_damping=qd.static(static_rigid_sim_config.integrator == gs.integrator.approximate_implicitfast), links_state=links_state, @@ -1438,15 +1312,6 @@ def kernel_forward_dynamics_without_qacc( static_rigid_sim_config=static_rigid_sim_config, is_backward=is_backward, ) - func_factor_mass( - implicit_damping=False, - entities_info=entities_info, - dofs_state=dofs_state, - dofs_info=dofs_info, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, - ) func_torque_and_passive_force( entities_state=entities_state, entities_info=entities_info, @@ -1539,16 +1404,13 @@ def func_implicit_damping( dofs_info=dofs_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) func_solve_mass( vec=dofs_state.force, out=dofs_state.acc, - out_bw=dofs_state.acc_bw, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, ) # Disable pre-computed factorization mask right away diff --git a/genesis/engine/solvers/rigid/abd/forward_kinematics.py b/genesis/engine/solvers/rigid/abd/forward_kinematics.py index db7b17bfb8..61caf9e438 100644 --- a/genesis/engine/solvers/rigid/abd/forward_kinematics.py +++ b/genesis/engine/solvers/rigid/abd/forward_kinematics.py @@ -638,6 +638,12 @@ def func_forward_kinematics_entity( + joints_state.xaxis[i_j, i_b] * dofs_state.pos[dof_start, i_b] ) pos = W(links_state.pos_bw, next_I, pos_, BW) + # Prismatic doesn't rotate the link, but the per-joint cache (quat_bw) still + # needs `next_I` populated — otherwise the final `R(quat_bw, I_jf, ...)` at the + # end of this function reads uninitialised memory in the backward-mode kernel, + # which surfaces as NaN gradients on qpos. Commit the unchanged quat to the + # next slot so the cache chain is contiguous through the loop iteration. + quat = W(links_state.quat_bw, next_I, quat, BW) # Skip link pose update for fixed root links to let users manually overwrite them I_jf = (i_l, 0 if qd.static(not BW) else n_joints, i_b) @@ -1596,3 +1602,103 @@ def kernel_update_cartesian_space( force_update_fixed_geoms=force_update_fixed_geoms, is_backward=is_backward, ) + + +# --------------------------------------------------------------------------- +# Standalone forward-replay kernels for the `update_cartesian_space` (UCS) +# sub-stages, used only in the backward pass (`substep_pre_coupling_grad`). +# +# The backward unroll replays each UCS sub-step with `is_backward=True` (static +# loops) and then reverses it. The reverse is NOT uniform across sub-steps: +# * COM-links and geom-pose updates reverse cleanly through Quadrants +# autograd, so their replay kernel's `.grad` is called directly. +# * FK and forward-velocity are reversed *manually* (`kernel_manual_forward_kinematics_bw` +# / `kernel_manual_forward_velocity_bw`) because autograd silently drops +# their gradient. +# Splitting UCS into one standalone kernel per stage lets the backward mix +# autograd `.grad` and manual reverses stage-by-stage. +# +# TODO: once every UCS sub-function reverses correctly under autograd, drop +# these standalone kernels (and the manual reverses) and differentiate +# `kernel_update_cartesian_space` as a whole. +# --------------------------------------------------------------------------- +@qd.kernel(fastcache=True) +def kernel_forward_kinematics_replay( + envs_idx: qd.types.ndarray(), + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), + is_backward: qd.template(), +): + for i_b_ in range(envs_idx.shape[0]): + i_b = qd.cast(envs_idx[i_b_], qd.i32) + func_forward_kinematics_batch( + i_b=i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, + ) + + +@qd.kernel(fastcache=True) +def kernel_update_geoms_replay( + entities_info: array_class.EntitiesInfo, + geoms_state: array_class.GeomsState, + geoms_info: array_class.GeomsInfo, + links_state: array_class.LinksState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), + is_backward: qd.template(), +): + func_update_geoms( + entities_info, + geoms_state, + geoms_info, + links_state, + rigid_global_info, + static_rigid_sim_config, + force_update_fixed_geoms=False, + is_backward=is_backward, + ) + + +@qd.kernel(fastcache=True) +def kernel_COM_links_replay( + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), + is_backward: qd.template(), +): + for i_b in range(links_state.pos.shape[1]): + func_COM_links( + i_b=i_b, + links_state=links_state, + links_info=links_info, + joints_state=joints_state, + joints_info=joints_info, + dofs_state=dofs_state, + dofs_info=dofs_info, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + is_backward=is_backward, + ) diff --git a/genesis/engine/solvers/rigid/abd/manual_bw.py b/genesis/engine/solvers/rigid/abd/manual_bw.py new file mode 100644 index 0000000000..31f9ff97c5 --- /dev/null +++ b/genesis/engine/solvers/rigid/abd/manual_bw.py @@ -0,0 +1,769 @@ +"""Manual reverse-mode kernels for the rigid backward pass. + +Where Quadrants AD silently drops the reverse chain, we compute +the Jacobian-transpose by hand instead. + +Kernels: + - kernel_manual_forward_kinematics_bw : forward-kinematics reverse + (link pos/quat grad -> qpos / dofs_vel grad). + - kernel_manual_forward_velocity_bw : link-velocity propagation reverse. + - kernel_manual_compute_qacc_bw : reverse of `acc = M^{-1} force` via the + Implicit Function Theorem (writes force.grad + mass_mat.grad). + +Hibernation is not supported and sets the `errno` field +(`ErrorCode.MANUAL_BW_UNIMPLEMENTED`) rather than silently corrupting +gradients. + +The `@qd.func` helpers below (`d_transform_by_quat__dq`, `d_quat_mul__dlhs` / +`d_quat_mul__drhs`, `d_rotvec_to_quat__drotvec`, `d_motion_cross_motion`) are +the hand-written chain-rule derivatives of the corresponding +`genesis/utils/geom.py` forward functions, used as building blocks above. +""" + +import quadrants as qd + +import genesis as gs +import genesis.utils.array_class as array_class +import genesis.utils.geom as gu + + +@qd.func +def d_transform_by_quat__dq(v, quat, out_grad): + """Gradient w.r.t. `quat` of `qd_transform_by_quat(v, quat)`. + + Forward (geom.py:294): + out[0] = v0·(qw² + qx² - qy² - qz²) + v1·(2qxy - 2qwz) + v2·(2qxz + 2qwy) + out[1] = v0·(2qxy + 2qwz) + v1·(qw² - qx² + qy² - qz²) + v2·(2qyz - 2qwx) + out[2] = v0·(2qxz - 2qwy) + v1·(2qyz + 2qwx) + v2·(qw² - qx² - qy² + qz²) + + Returns Vec4 = (∂L/∂qw, ∂L/∂qx, ∂L/∂qy, ∂L/∂qz) where + L is whatever scalar seeded `out_grad`. (No normalization assumed.) + """ + qw = quat[0] + qx = quat[1] + qy = quat[2] + qz = quat[3] + v0 = v[0] + v1 = v[1] + v2 = v[2] + og0 = out_grad[0] + og1 = out_grad[1] + og2 = out_grad[2] + + # ∂out[0]/∂{w,x,y,z} + do0_dqw = 2.0 * (qw * v0 - qz * v1 + qy * v2) + do0_dqx = 2.0 * (qx * v0 + qy * v1 + qz * v2) + do0_dqy = 2.0 * (-qy * v0 + qx * v1 + qw * v2) + do0_dqz = 2.0 * (-qz * v0 - qw * v1 + qx * v2) + + # ∂out[1]/∂{w,x,y,z} + do1_dqw = 2.0 * (qz * v0 + qw * v1 - qx * v2) + do1_dqx = 2.0 * (qy * v0 - qx * v1 - qw * v2) + do1_dqy = 2.0 * (qx * v0 + qy * v1 + qz * v2) + do1_dqz = 2.0 * (qw * v0 - qz * v1 + qy * v2) + + # ∂out[2]/∂{w,x,y,z} + do2_dqw = 2.0 * (-qy * v0 + qx * v1 + qw * v2) + do2_dqx = 2.0 * (qz * v0 + qw * v1 - qx * v2) + do2_dqy = 2.0 * (-qw * v0 + qz * v1 - qy * v2) + do2_dqz = 2.0 * (qx * v0 + qy * v1 + qz * v2) + + return qd.Vector( + [ + og0 * do0_dqw + og1 * do1_dqw + og2 * do2_dqw, + og0 * do0_dqx + og1 * do1_dqx + og2 * do2_dqx, + og0 * do0_dqy + og1 * do1_dqy + og2 * do2_dqy, + og0 * do0_dqz + og1 * do1_dqz + og2 * do2_dqz, + ], + dt=gs.qd_float, + ) + + +@qd.func +def d_quat_mul__dlhs(a, b, out_grad): + """Gradient w.r.t. `a` of `quat_mul(a, b)` (Hamilton convention). + + Forward (geom.py qd_quat_mul): + out_w = aw·bw - ax·bx - ay·by - az·bz + out_x = aw·bx + ax·bw + ay·bz - az·by + out_y = aw·by - ax·bz + ay·bw + az·bx + out_z = aw·bz + ax·by - ay·bx + az·bw + """ + bw = b[0] + bx = b[1] + by = b[2] + bz = b[3] + ogw = out_grad[0] + ogx = out_grad[1] + ogy = out_grad[2] + ogz = out_grad[3] + return qd.Vector( + [ + # ∂L/∂aw + ogw * bw + ogx * bx + ogy * by + ogz * bz, + # ∂L/∂ax + -ogw * bx + ogx * bw - ogy * bz + ogz * by, + # ∂L/∂ay + -ogw * by + ogx * bz + ogy * bw - ogz * bx, + # ∂L/∂az + -ogw * bz - ogx * by + ogy * bx + ogz * bw, + ], + dt=gs.qd_float, + ) + + +@qd.func +def d_quat_mul__drhs(a, b, out_grad): + """Gradient w.r.t. `b` of `quat_mul(a, b)`.""" + aw = a[0] + ax = a[1] + ay = a[2] + az = a[3] + ogw = out_grad[0] + ogx = out_grad[1] + ogy = out_grad[2] + ogz = out_grad[3] + return qd.Vector( + [ + # ∂L/∂bw + ogw * aw + ogx * ax + ogy * ay + ogz * az, + # ∂L/∂bx + -ogw * ax + ogx * aw + ogy * az - ogz * ay, + # ∂L/∂by + -ogw * ay - ogx * az + ogy * aw + ogz * ax, + # ∂L/∂bz + -ogw * az + ogx * ay - ogy * ax + ogz * aw, + ], + dt=gs.qd_float, + ) + + +@qd.func +def d_rotvec_to_quat__drotvec(rotvec, eps, quat_grad): + """Gradient w.r.t. `rotvec` of `qd_rotvec_to_quat(rotvec, eps)`. + + Forward: + thetasq = rx² + ry² + rz² + theta_reg = sqrt(thetasq + eps²) + c = cos(theta_reg / 2) + sinc = sin(theta_reg / 2) / theta_reg + quat = (c, sinc·rx, sinc·ry, sinc·rz) + + Backward — by chain rule on theta_reg(rx, ry, rz): + ∂theta_reg/∂ri = ri / theta_reg + ∂c/∂ri = -0.5·sin(theta_reg/2)·ri/theta_reg + = -0.5·(sin·ri)/theta_reg + ∂sinc/∂ri = [(0.5·cos(theta_reg/2))/theta_reg + - sin(theta_reg/2)/theta_reg²] · ri/theta_reg + = ri·(0.5·c/theta_reg² - sinc/theta_reg²) + + ∂quat[0]/∂ri = ∂c/∂ri = -0.5·sin·ri/theta_reg + ∂quat[1+j]/∂ri = ∂(sinc·r_j)/∂ri + = δ(i,j)·sinc + r_j·∂sinc/∂ri + + So rotvec_grad[i] = quat_grad[0]·(-0.5·sin·ri/theta_reg) + + sum_j quat_grad[1+j] · [δ(i,j)·sinc + r_j·∂sinc/∂ri] + = quat_grad[0]·(-0.5·sin·ri/theta_reg) + + sinc·quat_grad[1+i] + + ∂sinc/∂ri · sum_j quat_grad[1+j]·r_j + """ + rx = rotvec[0] + ry = rotvec[1] + rz = rotvec[2] + thetasq = rx * rx + ry * ry + rz * rz + theta_reg = qd.sqrt(thetasq + eps * eps) + theta_half = 0.5 * theta_reg + sin_h = qd.sin(theta_half) + cos_h = qd.cos(theta_half) + sinc = sin_h / theta_reg + # ∂sinc/∂theta_reg = (0.5·cos_h - sinc) / theta_reg + dsinc_dtheta = (0.5 * cos_h - sinc) / theta_reg + + qg_w = quat_grad[0] + qg_x = quat_grad[1] + qg_y = quat_grad[2] + qg_z = quat_grad[3] + + # sum_j quat_grad[1+j] · r_j + qg_dot_r = qg_x * rx + qg_y * ry + qg_z * rz + + # ∂quat[0]/∂ri = -0.5·sin_h·ri/theta_reg + # ∂(sinc·rj)/∂ri = δij·sinc + r_j·(dsinc_dtheta · ri/theta_reg) + # so total per i: + # ri·[ -0.5·sin_h/theta_reg · qg_w + dsinc_dtheta/theta_reg · qg_dot_r ] + sinc·qg_{x,y,z}[i] + coeff = -0.5 * sin_h / theta_reg * qg_w + dsinc_dtheta / theta_reg * qg_dot_r + return qd.Vector( + [ + coeff * rx + sinc * qg_x, + coeff * ry + sinc * qg_y, + coeff * rz + sinc * qg_z, + ], + dt=gs.qd_float, + ) + + +@qd.kernel(fastcache=True) +def kernel_manual_forward_kinematics_bw( + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_state: array_class.JointsState, + joints_info: array_class.JointsInfo, + dofs_info: array_class.DofsInfo, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), + errno: qd.Tensor, +): + """Single-call manual reverse of `kernel_forward_kinematics_replay`. + Iterates each entity's links leaf->root in one launch (so a child's + `parent.{pos,quat}.grad` write lands before the parent consumes it), and + within each link reverses the *full joint chain*. + + A link may carry more than one joint (e.g. a planar floating base = + slide-x + slide-z + hinge-y on one body). The forward composes all of them + in sequence and caches the per-joint intermediate pose in + `links_state.{pos,quat}_bw[i_l, k]`: slot 0 = the "arm base" (parent pose + composed with the link's fixed offset), slot k+1 = pose after joint k, slot + n_joints = the final link pose. We walk those slots in reverse: seed the + grad on the final pose, reverse joint k for k = n_joints-1 .. 0 (each step + consumes the grad on slot k+1, emits qpos.grad for that joint and the grad + on slot k), then reverse the arm-base composition (slot 0) into the parent's + pose grad. + + Each joint also feeds `joints_state.{xanchor,xaxis}` downstream (velocity + FK), so we fold those accumulated grads back through slot k as well. + + Joint types: FREE / REVOLUTE / PRISMATIC / SPHERICAL / FIXED. + """ + qd.loop_config( + name="manual_fk_only_bw", + serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL), + ) + for i_e, i_b in qd.ndrange(entities_info.n_links.shape[0], links_state.pos.shape[1]): + n_in_e = entities_info.n_links[i_e] + for i_l_rev in range(n_in_e): + i_l = entities_info.link_end[i_e] - 1 - i_l_rev + I_l = [i_l, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_l + parent_idx = links_info.parent_idx[I_l] + n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l] + + # Grad seeded on the final link pose (= slot n_joints). Carried + # backward through the joint chain; after the loop it holds the grad + # on slot 0 (arm base). + g_pos = links_state.pos.grad[i_l, i_b] + g_quat = links_state.quat.grad[i_l, i_b] + + for k_rev in range(n_joints): + k = n_joints - 1 - k_rev + i_j = links_info.joint_start[I_l] + k + I_j = [i_j, i_b] if qd.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = joints_info.type[I_j] + q_start = joints_info.q_start[I_j] + dof_start = joints_info.dof_start[I_j] + I_d = [dof_start, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else dof_start + + # Input pose to joint k (slot k), cached by the forward replay. + pos_in = links_state.pos_bw[i_l, k, i_b] + quat_in = links_state.quat_bw[i_l, k, i_b] + joint_pos_off = joints_info.pos[I_j] + xanchor_grad = joints_state.xanchor.grad[i_j, i_b] + xaxis_grad = joints_state.xaxis.grad[i_j, i_b] + + if joint_type == gs.JOINT_TYPE.FREE: + # Final pose is set absolutely from qpos (slot in unused); + # xanchor = qpos[0:3]. + for j in qd.static(range(3)): + rigid_global_info.qpos.grad[q_start + j, i_b] = ( + rigid_global_info.qpos.grad[q_start + j, i_b] + g_pos[j] + xanchor_grad[j] + ) + for j in qd.static(range(4)): + rigid_global_info.qpos.grad[q_start + 3 + j, i_b] = ( + rigid_global_info.qpos.grad[q_start + 3 + j, i_b] + g_quat[j] + ) + g_pos = qd.Vector([0.0, 0.0, 0.0], dt=gs.qd_float) + g_quat = qd.Vector([0.0, 0.0, 0.0, 0.0], dt=gs.qd_float) + + elif joint_type == gs.JOINT_TYPE.REVOLUTE: + axis = dofs_info.motion_ang[I_d] + angle = rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] + rotvec = axis * angle + qloc = gu.qd_rotvec_to_quat(rotvec, rigid_global_info.EPS[None]) + # quat_out = transform_quat_by_quat(qloc, quat_in) = quat_mul(quat_in, qloc) + quat_out = gu.qd_transform_quat_by_quat(qloc, quat_in) + + # pos_out = xanchor - transform(joint_pos_off, quat_out) + # xanchor = transform(joint_pos_off, quat_in) + pos_in + gq_out = g_quat - d_transform_by_quat__dq(joint_pos_off, quat_out, g_pos) + g_qloc = d_quat_mul__drhs(quat_in, qloc, gq_out) + g_quat_in_apply = d_quat_mul__dlhs(quat_in, qloc, gq_out) + rotvec_grad = d_rotvec_to_quat__drotvec(rotvec, rigid_global_info.EPS[None], g_qloc) + angle_grad = axis[0] * rotvec_grad[0] + axis[1] * rotvec_grad[1] + axis[2] * rotvec_grad[2] + rigid_global_info.qpos.grad[q_start, i_b] = rigid_global_info.qpos.grad[q_start, i_b] + angle_grad + + # grad into xanchor = g_pos (from pos_out) + downstream xanchor_grad + g_xanchor = g_pos + xanchor_grad + g_quat_in = ( + g_quat_in_apply + + d_transform_by_quat__dq(joint_pos_off, quat_in, g_xanchor) + + d_transform_by_quat__dq(axis, quat_in, xaxis_grad) + ) + g_pos = g_xanchor + g_quat = g_quat_in + + elif joint_type == gs.JOINT_TYPE.PRISMATIC: + axis = dofs_info.motion_vel[I_d] + displacement = rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] + xaxis = gu.qd_transform_by_quat(axis, quat_in) + # pos_out = pos_in + xaxis * displacement ; quat_out = quat_in + displacement_grad = xaxis[0] * g_pos[0] + xaxis[1] * g_pos[1] + xaxis[2] * g_pos[2] + rigid_global_info.qpos.grad[q_start, i_b] = ( + rigid_global_info.qpos.grad[q_start, i_b] + displacement_grad + ) + g_xaxis = qd.Vector( + [ + g_pos[0] * displacement + xaxis_grad[0], + g_pos[1] * displacement + xaxis_grad[1], + g_pos[2] * displacement + xaxis_grad[2], + ], + dt=gs.qd_float, + ) + g_xanchor = g_pos + xanchor_grad + g_quat_in = ( + g_quat + + d_transform_by_quat__dq(axis, quat_in, g_xaxis) + + d_transform_by_quat__dq(joint_pos_off, quat_in, g_xanchor) + ) + g_pos = g_xanchor + g_quat = g_quat_in + + elif joint_type == gs.JOINT_TYPE.SPHERICAL: + # qloc = qpos[q_start:q_start+4] (direct); quat_out = quat_mul(quat_in, qloc). + # axis defaults to [0,0,1] (xaxis = transform(axis, quat_in)). + axis = qd.Vector([0.0, 0.0, 1.0], dt=gs.qd_float) + qloc = qd.Vector( + [ + rigid_global_info.qpos[q_start, i_b], + rigid_global_info.qpos[q_start + 1, i_b], + rigid_global_info.qpos[q_start + 2, i_b], + rigid_global_info.qpos[q_start + 3, i_b], + ], + dt=gs.qd_float, + ) + quat_out = gu.qd_transform_quat_by_quat(qloc, quat_in) + gq_out = g_quat - d_transform_by_quat__dq(joint_pos_off, quat_out, g_pos) + g_qloc = d_quat_mul__drhs(quat_in, qloc, gq_out) + g_quat_in_apply = d_quat_mul__dlhs(quat_in, qloc, gq_out) + for j in qd.static(range(4)): + rigid_global_info.qpos.grad[q_start + j, i_b] = ( + rigid_global_info.qpos.grad[q_start + j, i_b] + g_qloc[j] + ) + g_xanchor = g_pos + xanchor_grad + g_quat_in = ( + g_quat_in_apply + + d_transform_by_quat__dq(joint_pos_off, quat_in, g_xanchor) + + d_transform_by_quat__dq(axis, quat_in, xaxis_grad) + ) + g_pos = g_xanchor + g_quat = g_quat_in + + else: # gs.JOINT_TYPE.FIXED — pose passes through unchanged. + pass + + for j in qd.static(range(3)): + joints_state.xanchor.grad[i_j, i_b][j] = 0.0 + joints_state.xaxis.grad[i_j, i_b][j] = 0.0 + + # Reverse the arm-base composition (slot 0): + # arm_base_pos = parent_pos + transform(link_offset_pos, parent_quat) + # arm_base_quat = quat_mul(parent_quat, link_offset_quat) + # propagating slot-0 grad (g_pos, g_quat) into the parent's pose grad. + if parent_idx != -1: + parent_quat = links_state.quat[parent_idx, i_b] + link_off_pos = links_info.pos[I_l] + link_off_quat = links_info.quat[I_l] + parent_quat_grad_from_pos = d_transform_by_quat__dq(link_off_pos, parent_quat, g_pos) + parent_quat_grad_from_quat = d_quat_mul__dlhs(parent_quat, link_off_quat, g_quat) + for j in qd.static(range(3)): + links_state.pos.grad[parent_idx, i_b][j] = links_state.pos.grad[parent_idx, i_b][j] + g_pos[j] + for j in qd.static(range(4)): + links_state.quat.grad[parent_idx, i_b][j] = ( + links_state.quat.grad[parent_idx, i_b][j] + + parent_quat_grad_from_pos[j] + + parent_quat_grad_from_quat[j] + ) + + for j in qd.static(range(3)): + links_state.pos.grad[i_l, i_b][j] = 0.0 + for j in qd.static(range(4)): + links_state.quat.grad[i_l, i_b][j] = 0.0 + + +@qd.func +def d_motion_cross_motion(s_ang, s_vel, m_ang, m_vel, ang_g, vel_g): + """Reverse of motion_cross_motion(s_ang, s_vel, m_ang, m_vel). + + Forward (geom.py:437): + vel = s_ang × m_vel + s_vel × m_ang + ang = s_ang × m_ang + + Chain rule (c=a×b ⇒ a.g += b × c.g, b.g += c.g × a): + s_ang.g += m_ang × ang.g + m_vel × vel.g + s_vel.g += m_ang × vel.g + m_ang.g += ang.g × s_ang + vel.g × s_vel + m_vel.g += vel.g × s_ang + + Returns (s_ang_g, s_vel_g, m_ang_g, m_vel_g) — additive deltas. + """ + return ( + m_ang.cross(ang_g) + m_vel.cross(vel_g), + m_ang.cross(vel_g), + ang_g.cross(s_ang) + vel_g.cross(s_vel), + vel_g.cross(s_ang), + ) + + +@qd.kernel(fastcache=True) +def kernel_manual_forward_velocity_bw( + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), + errno: qd.Tensor, +): + """Manual reverse of `kernel_forward_velocity` — single-call (no per-link + split). Replaces the diagnostic per-link split in `substep_pre_coupling_grad` + by computing the cross-link `cd_{vel,ang}[parent_idx]` chain explicitly. + + Inputs (read .grad seeds): + - cd_vel.grad[i_l, i_b], cd_ang.grad[i_l, i_b] + - cd_vel_bw.grad[i_l, k, i_b], cd_ang_bw.grad[i_l, k, i_b] + - cdofd_ang.grad[i_d, i_b], cdofd_vel.grad[i_d, i_b] + + Outputs (accumulated .grad): + - dofs_state.vel.grad[i_d, i_b] + - dofs_state.cdof_ang.grad[i_d, i_b], dofs_state.cdof_vel.grad[i_d, i_b] + - links_state.cd_vel.grad[parent_idx, i_b], links_state.cd_ang.grad[parent_idx, i_b] + (cross-link chain — equivalent to forward replay's BW=True + `cd_*_bw[i_l, 0] = parent.cd_*`) + """ + qd.loop_config( + name="manual_forward_velocity_bw", + serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL), + ) + for i_e, i_b in qd.ndrange(entities_info.n_links.shape[0], links_state.pos.shape[1]): + if qd.static(static_rigid_sim_config.use_hibernation): + errno[i_b] = errno[i_b] | array_class.ErrorCode.MANUAL_BW_UNIMPLEMENTED + else: + n_in_e = entities_info.n_links[i_e] + # Leaf → root iteration so each link's cd_*_bw[0].grad (which + # accumulates into parent.cd_*.grad) is propagated *before* the + # parent's own iteration uses it. + for i_l_rev in range(n_in_e): + i_l = entities_info.link_end[i_e] - 1 - i_l_rev + I_l = [i_l, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_l + n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l] + i_p = links_info.parent_idx[I_l] + + # ── Step 1 reverse: cd_*[i_l].grad → cd_*_bw[i_l, n_joints].grad + for k in qd.static(range(3)): + links_state.cd_vel_bw.grad[i_l, n_joints, i_b][k] = ( + links_state.cd_vel_bw.grad[i_l, n_joints, i_b][k] + links_state.cd_vel.grad[i_l, i_b][k] + ) + links_state.cd_ang_bw.grad[i_l, n_joints, i_b][k] = ( + links_state.cd_ang_bw.grad[i_l, n_joints, i_b][k] + links_state.cd_ang.grad[i_l, i_b][k] + ) + # consume cd_vel/cd_ang.grad[i_l] + for k in qd.static(range(3)): + links_state.cd_vel.grad[i_l, i_b][k] = 0.0 + links_state.cd_ang.grad[i_l, i_b][k] = 0.0 + + # ── Step 2: iterate joints in reverse + for i_j_rev in range(n_joints): + i_j_ = n_joints - 1 - i_j_rev + i_j = i_j_ + links_info.joint_start[I_l] + I_j = [i_j, i_b] if qd.static(static_rigid_sim_config.batch_joints_info) else i_j + jt = joints_info.type[I_j] + ds = joints_info.dof_start[I_j] + de = joints_info.dof_end[I_j] + curr_idx = i_j_ + next_idx = i_j_ + 1 + + # ── [d-rev] cd_*_bw[next].grad → cdof_*.grad / vel.grad + # Forward (FREE angular: i_3=0..2 at d=ds+3+i_3; else: d in ds..de): + # _vel = cdof_vel[d] * vel[d]; atomic_add(cd_vel_bw[next], _vel) + # _ang = cdof_ang[d] * vel[d]; atomic_add(cd_ang_bw[next], _ang) + cvg_next = links_state.cd_vel_bw.grad[i_l, next_idx, i_b] + cag_next = links_state.cd_ang_bw.grad[i_l, next_idx, i_b] + if jt == gs.JOINT_TYPE.FREE: + for i_3 in qd.static(range(3)): + d_i = ds + 3 + i_3 + v_at_d = dofs_state.vel[d_i, i_b] + cdv = dofs_state.cdof_vel[d_i, i_b] + cda = dofs_state.cdof_ang[d_i, i_b] + for k in qd.static(range(3)): + dofs_state.cdof_vel.grad[d_i, i_b][k] = ( + dofs_state.cdof_vel.grad[d_i, i_b][k] + cvg_next[k] * v_at_d + ) + dofs_state.cdof_ang.grad[d_i, i_b][k] = ( + dofs_state.cdof_ang.grad[d_i, i_b][k] + cag_next[k] * v_at_d + ) + dot_vel = cdv[0] * cvg_next[0] + cdv[1] * cvg_next[1] + cdv[2] * cvg_next[2] + dot_ang = cda[0] * cag_next[0] + cda[1] * cag_next[1] + cda[2] * cag_next[2] + dofs_state.vel.grad[d_i, i_b] = dofs_state.vel.grad[d_i, i_b] + dot_vel + dot_ang + else: + for i_d in range(ds, de): + v_at_d = dofs_state.vel[i_d, i_b] + cdv = dofs_state.cdof_vel[i_d, i_b] + cda = dofs_state.cdof_ang[i_d, i_b] + for k in qd.static(range(3)): + dofs_state.cdof_vel.grad[i_d, i_b][k] = ( + dofs_state.cdof_vel.grad[i_d, i_b][k] + cvg_next[k] * v_at_d + ) + dofs_state.cdof_ang.grad[i_d, i_b][k] = ( + dofs_state.cdof_ang.grad[i_d, i_b][k] + cag_next[k] * v_at_d + ) + dot_vel = cdv[0] * cvg_next[0] + cdv[1] * cvg_next[1] + cdv[2] * cvg_next[2] + dot_ang = cda[0] * cag_next[0] + cda[1] * cag_next[1] + cda[2] * cag_next[2] + dofs_state.vel.grad[i_d, i_b] = dofs_state.vel.grad[i_d, i_b] + dot_vel + dot_ang + + # ── [c-rev] cd_*_bw[next] = cd_*_bw[curr] → curr.grad += next.grad + for k in qd.static(range(3)): + links_state.cd_vel_bw.grad[i_l, curr_idx, i_b][k] = ( + links_state.cd_vel_bw.grad[i_l, curr_idx, i_b][k] + cvg_next[k] + ) + links_state.cd_ang_bw.grad[i_l, curr_idx, i_b][k] = ( + links_state.cd_ang_bw.grad[i_l, curr_idx, i_b][k] + cag_next[k] + ) + # consume next + for k in qd.static(range(3)): + links_state.cd_vel_bw.grad[i_l, next_idx, i_b][k] = 0.0 + links_state.cd_ang_bw.grad[i_l, next_idx, i_b][k] = 0.0 + + # ── [b-rev] motion_cross_motion reverse: + # Forward: (cdofd_ang[d_i], cdofd_vel[d_i]) = + # motion_cross_motion(cd_ang_bw[curr], cd_vel_bw[curr], cdof_ang[d_i], cdof_vel[d_i]) + # Reverse via d_motion_cross_motion(s_ang, s_vel, m_ang, m_vel, ang_g, vel_g) + s_ang_primal = links_state.cd_ang_bw[i_l, curr_idx, i_b] + s_vel_primal = links_state.cd_vel_bw[i_l, curr_idx, i_b] + if jt == gs.JOINT_TYPE.FREE: + # Angular dofs i_3=0..2 at d_i = ds + 3 + i_3 (linear cdofd_* are explicit 0) + for i_3 in qd.static(range(3)): + d_i = ds + 3 + i_3 + ang_g = dofs_state.cdofd_ang.grad[d_i, i_b] + vel_g = dofs_state.cdofd_vel.grad[d_i, i_b] + cda = dofs_state.cdof_ang[d_i, i_b] + cdv = dofs_state.cdof_vel[d_i, i_b] + s_ang_g, s_vel_g, m_ang_g, m_vel_g = d_motion_cross_motion( + s_ang_primal, s_vel_primal, cda, cdv, ang_g, vel_g + ) + for k in qd.static(range(3)): + links_state.cd_ang_bw.grad[i_l, curr_idx, i_b][k] = ( + links_state.cd_ang_bw.grad[i_l, curr_idx, i_b][k] + s_ang_g[k] + ) + links_state.cd_vel_bw.grad[i_l, curr_idx, i_b][k] = ( + links_state.cd_vel_bw.grad[i_l, curr_idx, i_b][k] + s_vel_g[k] + ) + dofs_state.cdof_ang.grad[d_i, i_b][k] = ( + dofs_state.cdof_ang.grad[d_i, i_b][k] + m_ang_g[k] + ) + dofs_state.cdof_vel.grad[d_i, i_b][k] = ( + dofs_state.cdof_vel.grad[d_i, i_b][k] + m_vel_g[k] + ) + # consume cdofd_*.grad[d_i] + for k in qd.static(range(3)): + dofs_state.cdofd_ang.grad[d_i, i_b][k] = 0.0 + dofs_state.cdofd_vel.grad[d_i, i_b][k] = 0.0 + # Linear dofs (i_3=0..2 at d_i = ds + i_3): cdofd_* set to 0 + # (constant), reverse is no-op; just consume to mirror P8. + for i_3 in qd.static(range(3)): + d_i = ds + i_3 + for k in qd.static(range(3)): + dofs_state.cdofd_ang.grad[d_i, i_b][k] = 0.0 + dofs_state.cdofd_vel.grad[d_i, i_b][k] = 0.0 + else: + for i_d in range(ds, de): + ang_g = dofs_state.cdofd_ang.grad[i_d, i_b] + vel_g = dofs_state.cdofd_vel.grad[i_d, i_b] + cda = dofs_state.cdof_ang[i_d, i_b] + cdv = dofs_state.cdof_vel[i_d, i_b] + s_ang_g, s_vel_g, m_ang_g, m_vel_g = d_motion_cross_motion( + s_ang_primal, s_vel_primal, cda, cdv, ang_g, vel_g + ) + for k in qd.static(range(3)): + links_state.cd_ang_bw.grad[i_l, curr_idx, i_b][k] = ( + links_state.cd_ang_bw.grad[i_l, curr_idx, i_b][k] + s_ang_g[k] + ) + links_state.cd_vel_bw.grad[i_l, curr_idx, i_b][k] = ( + links_state.cd_vel_bw.grad[i_l, curr_idx, i_b][k] + s_vel_g[k] + ) + dofs_state.cdof_ang.grad[i_d, i_b][k] = ( + dofs_state.cdof_ang.grad[i_d, i_b][k] + m_ang_g[k] + ) + dofs_state.cdof_vel.grad[i_d, i_b][k] = ( + dofs_state.cdof_vel.grad[i_d, i_b][k] + m_vel_g[k] + ) + for k in qd.static(range(3)): + dofs_state.cdofd_ang.grad[i_d, i_b][k] = 0.0 + dofs_state.cdofd_vel.grad[i_d, i_b][k] = 0.0 + + # ── [a-rev] (FREE only) cd_*_bw[curr].grad → linear cdof_*.grad / vel.grad + # Forward (FREE linear pre-motion_cross_motion): for i_3=0..2 at d_i = ds + i_3, + # _vel = cdof_vel[d_i] * vel[d_i]; atomic_add(cd_vel_bw[curr], _vel) + # _ang = cdof_ang[d_i] * vel[d_i]; atomic_add(cd_ang_bw[curr], _ang) + # (cdof_vel[linear] = e_i_3 constant; cdof_ang[linear] = 0 constant) + if jt == gs.JOINT_TYPE.FREE: + cvg_curr = links_state.cd_vel_bw.grad[i_l, curr_idx, i_b] + cag_curr = links_state.cd_ang_bw.grad[i_l, curr_idx, i_b] + for i_3 in qd.static(range(3)): + d_i = ds + i_3 + v_at_d = dofs_state.vel[d_i, i_b] + cdv = dofs_state.cdof_vel[d_i, i_b] + cda = dofs_state.cdof_ang[d_i, i_b] + for k in qd.static(range(3)): + dofs_state.cdof_vel.grad[d_i, i_b][k] = ( + dofs_state.cdof_vel.grad[d_i, i_b][k] + cvg_curr[k] * v_at_d + ) + dofs_state.cdof_ang.grad[d_i, i_b][k] = ( + dofs_state.cdof_ang.grad[d_i, i_b][k] + cag_curr[k] * v_at_d + ) + dot_vel = cdv[0] * cvg_curr[0] + cdv[1] * cvg_curr[1] + cdv[2] * cvg_curr[2] + dot_ang = cda[0] * cag_curr[0] + cda[1] * cag_curr[1] + cda[2] * cag_curr[2] + dofs_state.vel.grad[d_i, i_b] = dofs_state.vel.grad[d_i, i_b] + dot_vel + dot_ang + + # ── Step 1 (initial cvel setup) reverse: + # Forward: cd_*_bw[i_l, 0, i_b] = parent.cd_*[i_p, i_b] (if i_p != -1) else 0 + # Reverse: parent.cd_*.grad[i_p] += cd_*_bw[i_l, 0].grad; consume slot 0 + slot0_v_g = links_state.cd_vel_bw.grad[i_l, 0, i_b] + slot0_a_g = links_state.cd_ang_bw.grad[i_l, 0, i_b] + if i_p != -1: + for k in qd.static(range(3)): + links_state.cd_vel.grad[i_p, i_b][k] = links_state.cd_vel.grad[i_p, i_b][k] + slot0_v_g[k] + links_state.cd_ang.grad[i_p, i_b][k] = links_state.cd_ang.grad[i_p, i_b][k] + slot0_a_g[k] + # consume slot 0 + for k in qd.static(range(3)): + links_state.cd_vel_bw.grad[i_l, 0, i_b][k] = 0.0 + links_state.cd_ang_bw.grad[i_l, 0, i_b][k] = 0.0 + + +@qd.kernel(fastcache=True) +def kernel_manual_compute_qacc_bw( + dofs_state: array_class.DofsState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), +): + """Manual backward for `func_compute_qacc` via Implicit Function Theorem. + + Forward chain (`func_compute_qacc`): + acc_smooth = M^{-1} . force (via LDLT solve in `func_solve_mass`) + acc[i] = acc_smooth[i] (identity copy) + + Reverse chain (manual, by IFT and symmetry of M = L^T D L): + acc_smooth.grad += acc.grad (reverse of identity copy) + acc.grad = 0 (forward overwrites acc) + force_contrib = M^{-1} . acc_smooth.grad + force.grad += force_contrib + (M is symmetric, so M^{-T} = M^{-1}) + # Symmetric lower-tri storage of mass_mat: forward only reads + # mass_mat[i, j] for i >= j (lower-tri). Each off-diagonal parameter + # appears once in the forward, but the implicit symmetry means the + # IFT contribution to the lower-tri parameter combines both `(i, j)` + # and `(j, i)` chain terms (the upper mirror is logically present + # via symmetry of the matrix being inverted). + mass_mat[i, i].grad += -force_contrib[i] * acc_smooth[i] + mass_mat[i, j].grad += -(force_contrib[i] * acc_smooth[j] + + force_contrib[j] * acc_smooth[i]) (i > j) + + LDLT structure (Genesis transposed convention M = L^T D L): + Step 1: solve L^T . u = acc_smooth.grad (descending i_d) + Step 2: v = D^{-1} . u + Step 3: solve L . delta_force_grad = v (ascending i_d) + + Inputs vs outputs (note the asymmetry between the values *read* from + `rigid_global_info` and the grad fields *written* into it): + Reads: + - dofs_state.acc.grad, dofs_state.acc_smooth.grad (seed) + - rigid_global_info.mass_mat_L (LDLT solve) + - rigid_global_info.mass_mat_D_inv (LDLT solve) + - dofs_state.acc_smooth (IFT outer product) + Writes: + - dofs_state.force.grad (M^{-1} . seed) + - rigid_global_info.mass_mat.grad (IFT seed) + - dofs_state.acc.grad, dofs_state.acc_smooth.grad (consumed → 0) + + The dense `mass_mat` is *not* read here (the forward already factored it + into `mass_mat_L` / `mass_mat_D_inv`), but its `.grad` is the parameter + the IFT chain naturally exposes, so this kernel is the only place + `mass_mat.grad` gets populated in the backward path. + """ + qd.loop_config(serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)) + for i_e, i_b in qd.ndrange(entities_info.n_links.shape[0], dofs_state.force.shape[1]): + if rigid_global_info.mass_mat_mask[i_e, i_b]: + entity_dof_start = entities_info.dof_start[i_e] + entity_dof_end = entities_info.dof_end[i_e] + n_dofs = entities_info.n_dofs[i_e] + + # Reverse of `acc[i] = acc_smooth[i]`: drain acc.grad → seed acc_smooth.grad. + # Stash the combined seed in `acc_smooth_bw[0]` as the input to the LDLT + # reverse solve. Then zero `acc.grad` since the forward copy overwrites + # `acc` (its old value is destroyed, so any prior `acc.grad` is consumed). + for i_d in range(entity_dof_start, entity_dof_end): + dofs_state.acc_smooth_bw[0, i_d, i_b] = ( + dofs_state.acc_smooth.grad[i_d, i_b] + dofs_state.acc.grad[i_d, i_b] + ) + dofs_state.acc.grad[i_d, i_b] = 0.0 + dofs_state.acc_smooth.grad[i_d, i_b] = 0.0 + + # Step 1: solve L^T . u = seed (input from [0], output to [1]) + # u[i] = seed[i] - sum_{j>i} L[j,i] * u[j] + for i_d_ in range(n_dofs): + i_d = entity_dof_end - i_d_ - 1 + curr = dofs_state.acc_smooth_bw[0, i_d, i_b] + for j_d in range(i_d + 1, entity_dof_end): + curr = curr - rigid_global_info.mass_mat_L[j_d, i_d, i_b] * dofs_state.acc_smooth_bw[1, j_d, i_b] + dofs_state.acc_smooth_bw[1, i_d, i_b] = curr + + # Step 2: v = D^{-1} . u (output to [0], overwriting input) + for i_d in range(entity_dof_start, entity_dof_end): + dofs_state.acc_smooth_bw[0, i_d, i_b] = ( + dofs_state.acc_smooth_bw[1, i_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b] + ) + + # Step 3: solve L . delta = v (input from [0], output to [1]) + # delta[i] = v[i] - sum_{j j): ∂L/∂M_ij = -(force_contrib[i] · acc_smooth[j] + # + force_contrib[j] · acc_smooth[i]) + # The off-diagonal sum picks up the chain through both the + # (i, j) and (j, i) occurrences of the parameter in the symmetric + # matrix. + for i_d in range(entity_dof_start, entity_dof_end): + fi = dofs_state.acc_smooth_bw[1, i_d, i_b] + ai = dofs_state.acc_smooth[i_d, i_b] + rigid_global_info.mass_mat.grad[i_d, i_d, i_b] = ( + rigid_global_info.mass_mat.grad[i_d, i_d, i_b] - fi * ai + ) + for j_d in range(entity_dof_start, i_d): + fj = dofs_state.acc_smooth_bw[1, j_d, i_b] + aj = dofs_state.acc_smooth[j_d, i_b] + rigid_global_info.mass_mat.grad[i_d, j_d, i_b] = rigid_global_info.mass_mat.grad[i_d, j_d, i_b] - ( + fi * aj + fj * ai + ) diff --git a/genesis/engine/solvers/rigid/collider/collider.py b/genesis/engine/solvers/rigid/collider/collider.py index 6a0ad21112..9c8630bee2 100644 --- a/genesis/engine/solvers/rigid/collider/collider.py +++ b/genesis/engine/solvers/rigid/collider/collider.py @@ -914,6 +914,18 @@ def detection(self) -> None: self._collider_static_config, ) + # Plane-convex contacts use analytic paths that don't fill + # `diff_contact_input`; populate it here so the differentiable + # narrow-phase reverse can reconstruct them (see + # `kernel_fill_diff_contact_input_plane`). + if self._solver._static_rigid_sim_config.requires_grad: + narrowphase.kernel_fill_diff_contact_input_plane( + self._solver.geoms_state, + self._solver.geoms_info, + self._solver._static_rigid_sim_config, + self._collider_state, + ) + def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch_dim: bool = False): # Early return if already pre-computed contact_data = self._contact_data_cache.setdefault((as_tensor, to_torch), {}) @@ -1049,8 +1061,9 @@ def get_contacts(self, as_tensor: bool = True, to_torch: bool = True, keep_batch def backward(self, dL_dposition, dL_dnormal, dL_dpenetration): func_set_upstream_grad(dL_dposition, dL_dnormal, dL_dpenetration, self._collider_state) + self.backward_narrowphase() - # Compute gradient + def backward_narrowphase(self): func_narrow_phase_diff_convex_vs_convex.grad( self._solver.geoms_state, self._solver.geoms_info, diff --git a/genesis/engine/solvers/rigid/collider/diff_gjk.py b/genesis/engine/solvers/rigid/collider/diff_gjk.py index 336a721eff..4e0726ac8c 100644 --- a/genesis/engine/solvers/rigid/collider/diff_gjk.py +++ b/genesis/engine/solvers/rigid/collider/diff_gjk.py @@ -862,6 +862,57 @@ def func_differentiable_contact( return contact_pos, contact_normal, penetration, weight +@qd.func +def func_differentiable_plane_contact( + geoms_state: array_class.GeomsState, + geoms_info: array_class.GeomsInfo, + diff_contact_input: array_class.DiffContactInput, + i_ga, + i_gb, + i_b, + i_c, +): + """Differentiable plane-convex contact reconstruction. + + Mirrors the analytic plane branch of `func_convex_convex_contact`: + normal = -normalize(R(quat_plane) @ plane_local_dir) + v_world = R(quat_convex) @ core_local + pos_convex + radius * normal + penetration = normal . (v_world - pos_plane) + contact_pos = v_world - 0.5 * penetration * normal + + `i_ga` is the PLANE geom, `i_gb` the convex geom. `core_local` (box vertex / + sphere center / capsule nearest endpoint, in convex-local frame) is the + stored non-diff witness; `radius` and `plane_local_dir` come from geoms_info. + Gradients flow to both geom poses through `geoms_state.{pos,quat}`. For a + sphere `core_local` is the local origin, so the orientation gradient is + naturally zero (the contact is rotation-invariant), matching the forward. + """ + trans_plane = geoms_state.pos[i_ga, i_b] + quat_plane = geoms_state.quat[i_ga, i_b] + trans_convex = geoms_state.pos[i_gb, i_b] + quat_convex = geoms_state.quat[i_gb, i_b] + + plane_dir = gs.qd_vec3(geoms_info.data[i_ga][0], geoms_info.data[i_ga][1], geoms_info.data[i_ga][2]) + plane_dir = gu.qd_transform_by_quat(plane_dir, quat_plane) + normal = -plane_dir.normalized() + + radius = gs.qd_float(0.0) + convex_type = geoms_info.type[i_gb] + if convex_type == gs.GEOM_TYPE.SPHERE: + radius = geoms_info.data[i_gb][0] + elif convex_type == gs.GEOM_TYPE.CAPSULE: + radius = geoms_info.data[i_gb][0] + + core_local = diff_contact_input.core_local[i_b, i_c] + core_world = gu.qd_transform_by_trans_quat(core_local, trans_convex, quat_convex) + v_world = core_world + radius * normal + + penetration = normal.dot(v_world - trans_plane) + contact_pos = v_world - 0.5 * penetration * normal + weight = gs.qd_float(1.0) + return contact_pos, normal, penetration, weight + + @qd.func def func_plane_normal(v1, v2, v3): """ diff --git a/genesis/engine/solvers/rigid/collider/narrowphase.py b/genesis/engine/solvers/rigid/collider/narrowphase.py index 696dcc1e61..07c3f3d122 100644 --- a/genesis/engine/solvers/rigid/collider/narrowphase.py +++ b/genesis/engine/solvers/rigid/collider/narrowphase.py @@ -2705,9 +2705,18 @@ def func_narrow_phase_diff_convex_vs_convex( if is_ref: ref_penetration = -1.0 - contact_pos, contact_normal, penetration, weight = diff_gjk.func_differentiable_contact( - geoms_state, diff_contact_input, gjk_info, i_ga, i_gb, i_b, i_c, ref_penetration - ) + contact_pos = gs.qd_vec3(0.0, 0.0, 0.0) + contact_normal = gs.qd_vec3(0.0, 0.0, 0.0) + penetration = gs.qd_float(0.0) + weight = gs.qd_float(0.0) + if geoms_info.type[i_ga] == gs.GEOM_TYPE.PLANE: + contact_pos, contact_normal, penetration, weight = diff_gjk.func_differentiable_plane_contact( + geoms_state, geoms_info, diff_contact_input, i_ga, i_gb, i_b, i_c + ) + else: + contact_pos, contact_normal, penetration, weight = diff_gjk.func_differentiable_contact( + geoms_state, diff_contact_input, gjk_info, i_ga, i_gb, i_b, i_c, ref_penetration + ) collider_state.diff_contact_input.ref_penetration[i_b, i_c] = penetration func_set_contact( @@ -2755,6 +2764,60 @@ def func_narrow_phase_diff_convex_vs_convex( ) +@qd.kernel(fastcache=True) +def kernel_fill_diff_contact_input_plane( + geoms_state: array_class.GeomsState, + geoms_info: array_class.GeomsInfo, + static_rigid_sim_config: qd.template(), + collider_state: array_class.ColliderState, +): + """Populate `diff_contact_input` for plane-convex contacts (forward, non-diff). + + The analytic plane paths (`func_plane_box_contact`, + `func_convex_convex_contact`'s plane branch) do NOT fill `diff_contact_input`, + so the differentiable narrow-phase reverse would have nothing to reconstruct. + Both paths use the same convention `contact_pos = v1 - 0.5*pen*normal` with + `normal = -normalize(R(quat_plane) @ plane_local_dir)`, so the convex support + point is recovered as `v1 = contact_pos + 0.5*pen*normal`, and its non-diff + "core" (vertex / sphere center / capsule endpoint) as `v1 - radius*normal`, + stored in the convex's local frame. PLANE is GEOM_TYPE 0 so it is always + `geom_a` after the canonical type-ordered swap. + """ + _B = collider_state.active_buffer.shape[1] + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) + for i_c, i_b in qd.ndrange(collider_state.contact_data.pos.shape[0], _B): + if i_c < collider_state.n_contacts[i_b]: + i_ga = collider_state.contact_data.geom_a[i_c, i_b] + i_gb = collider_state.contact_data.geom_b[i_c, i_b] + if geoms_info.type[i_ga] == gs.GEOM_TYPE.PLANE: + quat_plane = geoms_state.quat[i_ga, i_b] + trans_convex = geoms_state.pos[i_gb, i_b] + quat_convex = geoms_state.quat[i_gb, i_b] + + plane_dir = gs.qd_vec3(geoms_info.data[i_ga][0], geoms_info.data[i_ga][1], geoms_info.data[i_ga][2]) + plane_dir = gu.qd_transform_by_quat(plane_dir, quat_plane) + normal = -plane_dir.normalized() + + radius = gs.qd_float(0.0) + ctype = geoms_info.type[i_gb] + if ctype == gs.GEOM_TYPE.SPHERE: + radius = geoms_info.data[i_gb][0] + elif ctype == gs.GEOM_TYPE.CAPSULE: + radius = geoms_info.data[i_gb][0] + + pen = collider_state.contact_data.penetration[i_c, i_b] + cpos = collider_state.contact_data.pos[i_c, i_b] + v1 = cpos + 0.5 * pen * normal + core_world = v1 - radius * normal + core_local = gu.qd_transform_by_quat(core_world - trans_convex, gu.qd_inv_quat(quat_convex)) + + collider_state.diff_contact_input.geom_a[i_b, i_c] = i_ga + collider_state.diff_contact_input.geom_b[i_b, i_c] = i_gb + collider_state.diff_contact_input.core_local[i_b, i_c] = core_local + collider_state.diff_contact_input.ref_id[i_b, i_c] = i_c + collider_state.diff_contact_input.valid[i_b, i_c] = 1 + + @qd.kernel(fastcache=True) def func_narrow_phase_convex_specializations( geoms_state: array_class.GeomsState, diff --git a/genesis/engine/solvers/rigid/constraint/backward.py b/genesis/engine/solvers/rigid/constraint/backward.py index e72cac1146..9e6b3cbc76 100644 --- a/genesis/engine/solvers/rigid/constraint/backward.py +++ b/genesis/engine/solvers/rigid/constraint/backward.py @@ -2,6 +2,7 @@ import genesis as gs import genesis.utils.array_class as array_class +import genesis.utils.geom as gu @qd.func @@ -52,6 +53,63 @@ def func_matvec_Ap( constraint_state.bw_Ap[i_d, i_b] += constraint_state.jac[i_c, i_d, i_b] * jv +@qd.func +def func_solve_adjoint_u_cg_env( + i_b, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """CG solve of A u = g for a single environment `i_b`. + + `A = M + J^T diag(D) J` is applied implicitly by `func_matvec_Ap`, which + reads `rigid_global_info.mass_mat` directly and loops only over the active + constraints. So this also solves the unconstrained case `A = M` (no active + constraint -> empty J term). + """ + n_dofs = constraint_state.bw_u.shape[0] + + # r = g - A*0 = g ; p = r ; u = 0 + for i_d in range(n_dofs): + constraint_state.bw_u[i_d, i_b] = 0.0 + constraint_state.bw_r[i_d, i_b] = constraint_state.dL_dqacc[i_d, i_b] + constraint_state.bw_p[i_d, i_b] = constraint_state.bw_r[i_d, i_b] + + for it in range(rigid_global_info.iterations[None]): + func_matvec_Ap( + entities_info=entities_info, + rigid_global_info=rigid_global_info, + constraint_state=constraint_state, + static_rigid_sim_config=static_rigid_sim_config, + i_b=i_b, + ) + + # alpha = (r,r)/(p,Ap) + num = gs.qd_float(0.0) + den = gs.qd_float(0.0) + for i_d in range(n_dofs): + num += constraint_state.bw_r[i_d, i_b] * constraint_state.bw_r[i_d, i_b] + den += constraint_state.bw_p[i_d, i_b] * constraint_state.bw_Ap[i_d, i_b] + alpha = num / qd.max(den, rigid_global_info.EPS[None]) + + # u += alpha p ; r -= alpha Ap + for i_d in range(n_dofs): + constraint_state.bw_u[i_d, i_b] += alpha * constraint_state.bw_p[i_d, i_b] + constraint_state.bw_r[i_d, i_b] -= alpha * constraint_state.bw_Ap[i_d, i_b] + + if num < rigid_global_info.EPS[None]: + break + + # beta = (r_new,r_new)/(r_old,r_old) ; p = r + beta p + num_new = gs.qd_float(0.0) + for i_d in range(n_dofs): + num_new += constraint_state.bw_r[i_d, i_b] * constraint_state.bw_r[i_d, i_b] + beta = num_new / qd.max(num, rigid_global_info.EPS[None]) + for i_d in range(n_dofs): + constraint_state.bw_p[i_d, i_b] = constraint_state.bw_r[i_d, i_b] + beta * constraint_state.bw_p[i_d, i_b] + + @qd.kernel def kernel_solve_adjoint_u( entities_info: array_class.EntitiesInfo, @@ -77,76 +135,47 @@ def kernel_solve_adjoint_u( constraint_state.bw_u[i_d, i_b] = 0.0 if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton): - # Since we already have the Cholesky decomposition of A (= L * L^T), we can use it to solve A * u = g. for i_b in range(_B): - # z = L^{-1} g (forward substitution) - # Save solution to bw_r - for i_d in range(n_dofs): - z = constraint_state.dL_dqacc[i_d, i_b] - for j_d in range(i_d): - z -= constraint_state.nt_H[i_b, i_d, j_d] * constraint_state.bw_r[j_d, i_b] - z /= constraint_state.nt_H[i_b, i_d, i_d] - constraint_state.bw_r[i_d, i_b] = z - - # u = L^{-T} z (back substitution) - for i_d_ in range(n_dofs): - i_d = n_dofs - 1 - i_d_ - u = constraint_state.bw_r[i_d, i_b] - for j_d in range(i_d + 1, n_dofs): - u -= constraint_state.nt_H[i_b, j_d, i_d] * constraint_state.bw_u[j_d, i_b] - u /= constraint_state.nt_H[i_b, i_d, i_d] - constraint_state.bw_u[i_d, i_b] = u - else: - # Use CG solver for solving A * u = g. - # 2. Local buffers for solving A * u = g - # Initialize r, p with dL_dqacc - for i_d, i_b in qd.ndrange(n_dofs, _B): - # Residual: g - A * 0 (u = 0) - constraint_state.bw_r[i_d, i_b] = constraint_state.dL_dqacc[i_d, i_b] - # Search direction: p = r - constraint_state.bw_p[i_d, i_b] = constraint_state.bw_r[i_d, i_b] - - # 3. Solve A * u = g, parallelized over batch dimension - for i_b in range(_B): - # Compute Ap for the current search direction - for it in range(static_rigid_sim_config.iterations): - func_matvec_Ap( + if constraint_state.n_constraints[i_b] == 0: + # No active constraint: A = M. The forward's constrained-Hessian + # Cholesky `nt_H` is unreliable for these envs (the GPU tiled + # factorization skips n_c==0), so solve M u = g via CG, which + # reads `mass_mat` directly and never touches `nt_H`. + func_solve_adjoint_u_cg_env( + i_b=i_b, entities_info=entities_info, rigid_global_info=rigid_global_info, constraint_state=constraint_state, static_rigid_sim_config=static_rigid_sim_config, - i_b=i_b, ) - - # alpha = (r,r)/(p,Hp) - num = gs.qd_float(0.0) - den = gs.qd_float(0.0) - for i_d in range(n_dofs): - num += constraint_state.bw_r[i_d, i_b] * constraint_state.bw_r[i_d, i_b] - den += constraint_state.bw_p[i_d, i_b] * constraint_state.bw_Ap[i_d, i_b] - alpha = num / qd.max(den, rigid_global_info.EPS[None]) - - # u += alpha p ; r -= alpha Hp - for i_d in range(n_dofs): - constraint_state.bw_u[i_d, i_b] += alpha * constraint_state.bw_p[i_d, i_b] - constraint_state.bw_r[i_d, i_b] -= alpha * constraint_state.bw_Ap[i_d, i_b] - - # check tol (optional: per-batch) - # TODO: Might need lower tolerance? - if num < rigid_global_info.EPS[None]: - break - - # beta = (r_new,r_new)/(r_old,r_old) - num_new = gs.qd_float(0.0) + else: + # Reuse the forward's Cholesky decomposition A = L * L^T to solve A u = g. + # z = L^{-1} g (forward substitution); saved to bw_r for i_d in range(n_dofs): - num_new += constraint_state.bw_r[i_d, i_b] * constraint_state.bw_r[i_d, i_b] - beta = num_new / qd.max(num, rigid_global_info.EPS[None]) + z = constraint_state.dL_dqacc[i_d, i_b] + for j_d in range(i_d): + z -= constraint_state.nt_H[i_b, i_d, j_d] * constraint_state.bw_r[j_d, i_b] + z /= constraint_state.nt_H[i_b, i_d, i_d] + constraint_state.bw_r[i_d, i_b] = z - # p = r + beta p - for i_d in range(n_dofs): - constraint_state.bw_p[i_d, i_b] = ( - constraint_state.bw_r[i_d, i_b] + beta * constraint_state.bw_p[i_d, i_b] - ) + # u = L^{-T} z (back substitution) + for i_d_ in range(n_dofs): + i_d = n_dofs - 1 - i_d_ + u = constraint_state.bw_r[i_d, i_b] + for j_d in range(i_d + 1, n_dofs): + u -= constraint_state.nt_H[i_b, j_d, i_d] * constraint_state.bw_u[j_d, i_b] + u /= constraint_state.nt_H[i_b, i_d, i_d] + constraint_state.bw_u[i_d, i_b] = u + else: + # CG solver for A * u = g (parallelized over the batch dimension). + for i_b in range(_B): + func_solve_adjoint_u_cg_env( + i_b=i_b, + entities_info=entities_info, + rigid_global_info=rigid_global_info, + constraint_state=constraint_state, + static_rigid_sim_config=static_rigid_sim_config, + ) @qd.kernel @@ -258,3 +287,437 @@ def kernel_compute_gradients( val0 = -constraint_state.bw_u[i, i_b] * constraint_state.qacc[j, i_b] val1 = -constraint_state.bw_u[j, i_b] * constraint_state.qacc[i, i_b] constraint_state.dL_dM[i, j, i_b] += (val0 + val1) * 0.5 # symmetrize + + +@qd.kernel(fastcache=True) +def kernel_load_dL_dqacc_from_acc_grad( + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Copy `dofs_state.acc.grad` into `constraint_state.dL_dqacc` (input buffer + consumed by `kernel_solve_adjoint_u`) and zero the source grad so the + downstream IFT-through-M path does not re-consume it. + """ + _B = dofs_state.acc.shape[1] + n_dofs = dofs_state.acc.shape[0] + qd.loop_config( + name="kernel_load_dL_dqacc_from_acc_grad", + serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL), + ) + for i_d, i_b in qd.ndrange(n_dofs, _B): + constraint_state.dL_dqacc[i_d, i_b] = dofs_state.acc.grad[i_d, i_b] + dofs_state.acc.grad[i_d, i_b] = gs.qd_float(0.0) + + +@qd.kernel(fastcache=True) +def kernel_accumulate_constraint_solver_grads( + dofs_state: array_class.DofsState, + rigid_global_info: array_class.RigidGlobalInfo, + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + """Fold the constraint-solver adjoint outputs into the autograd grad fields: + dofs_state.force.grad += constraint_state.dL_dforce + rigid_global_info.mass_mat.grad += constraint_state.dL_dM + """ + _B = dofs_state.force.shape[1] + n_dofs = dofs_state.force.shape[0] + qd.loop_config( + name="kernel_accumulate_constraint_solver_grads", + serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL), + ) + for i_d, i_b in qd.ndrange(n_dofs, _B): + dofs_state.force.grad[i_d, i_b] += constraint_state.dL_dforce[i_d, i_b] + for i, j, i_b in qd.ndrange(n_dofs, n_dofs, _B): + rigid_global_info.mass_mat.grad[i, j, i_b] += constraint_state.dL_dM[i, j, i_b] + + +# --------------------------------------------------------------------------- +# Manual reverses of the constraint-force inequality constraints (collision, +# joint-limit). Shared conventions for the two kernels below. +# +# Why manual (not autograd): the constraint rows are built inside the forward +# solver with a data-dependent count and ordering -- `n_con` is assigned by +# atomic_add as active constraints are discovered -- which autograd cannot +# differentiate cleanly (the row index is not a static, taped quantity). +# +# Upstream grads: `kernel_compute_gradients` populates, per constraint row +# `n_con`, `constraint_state.dL_daref[n_con]` (dL/d aref), `dL_defc_D[n_con]` +# (dL/d efc_D), and `dL_djac[n_con, i_d]` (dL/d jac). The collision reverse uses +# `dL_djac`; the joint-limit reverse ignores it (its jac entries are piecewise- +# constant +-1, so the sub-gradient is 0). Each kernel consumes these and +# accumulates into its own differentiable inputs. +# +# `n_con` row layout: the forward adds constraints in the order frictionloss -> +# collision -> joint-limit (see `add_inequality_constraints`). So with collision +# on, the collision group occupies rows [0, 4 * n_contacts) (4 friction-pyramid +# rows per contact) and joint-limit rows follow it. Each reverse re-walks the +# same forward loop deterministically to recover its own `n_con` (no atomic_add, +# no `n_constraints` reset): collision uses `n_con = i_col * 4 + i`; joint-limit +# seeds its counter at `4 * n_contacts`. +# +# TODO: only collision + joint-limit are handled. If other constraint groups +# (equality, frictionloss) are ever added to a differentiable scene, the `n_con` +# offset in each reverse must be updated to account for them -- they are +# currently assumed absent (not offset). +# --------------------------------------------------------------------------- +@qd.kernel(fastcache=True) +def kernel_manual_add_joint_limit_constraints_bw( + links_info: array_class.LinksInfo, + joints_info: array_class.JointsInfo, + dofs_info: array_class.DofsInfo, + dofs_state: array_class.DofsState, + rigid_global_info: array_class.RigidGlobalInfo, + constraint_state: array_class.ConstraintState, + collider_state: array_class.ColliderState, + static_rigid_sim_config: qd.template(), + enable_collision: qd.template(), +): + """Manual reverse of `add_joint_limit_constraints`. See the section header + above for the shared `n_con` layout and upstream-grad conventions. + + Accumulates into rigid_global_info.qpos.grad[i_q] and dofs_state.vel.grad[i_d]. + + Chain rule (per active joint, `pos_delta < 0`): + + Forward: + pos_delta_min = qpos[i_q] - limit_lo + pos_delta_max = limit_hi - qpos[i_q] + pos_delta = min(pos_delta_min, pos_delta_max) + sign = +1 if pos_delta_min < pos_delta_max else -1 + jac_qvel = sign * dofs_vel[i_d] + imp, aref = gu.imp_aref(sol_params, pos_delta, jac_qvel, pos_delta) + diag_raw = invweight * (1 - imp) / imp + diag = max(diag_raw, EPS) + efc_D = 1 / diag + + d(pos_delta) / d(qpos) = sign (chosen branch of `min`) + d(jac_qvel) / d(vel) = sign + + dL/d(imp) = ga * d(aref)/d(imp) + gD * d(efc_D)/d(imp) + ga = dL_daref[n_con], gD = dL_defc_D[n_con] + + dL/d(pos_delta) = ga * d(aref)/d(pos_delta)|_direct + + dL/d(imp) * d(imp)/d(imp_x) * d(imp_x)/d(pos_delta) + + dL/d(jac_qvel) = ga * d(aref)/d(jac_qvel) = -ga * b_coef + + dL/d(qpos) += sign * dL/d(pos_delta) + dL/d(vel) += sign * dL/d(jac_qvel) + """ + EPS = rigid_global_info.EPS[None] + _B = constraint_state.jac.shape[2] + n_links = links_info.root_idx.shape[0] + + qd.loop_config( + name="kernel_manual_add_joint_limit_constraints_bw", + serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL), + ) + for i_b in range(_B): + # Collision constraints (4 rows per contact) are added before joint + # limits in `add_inequality_constraints`, so offset the joint-limit row + # counter past them when collision is on. + n_con_counter = gs.qd_int(0) + if qd.static(enable_collision): + n_con_counter = gs.qd_int(collider_state.n_contacts[i_b] * 4) + + for i_l in range(n_links): + I_l = [i_l, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_l + for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): + I_j = [i_j, i_b] if qd.static(static_rigid_sim_config.batch_joints_info) else i_j + + if joints_info.type[I_j] == gs.JOINT_TYPE.REVOLUTE or joints_info.type[I_j] == gs.JOINT_TYPE.PRISMATIC: + i_q = joints_info.q_start[I_j] + i_d = joints_info.dof_start[I_j] + I_d = [i_d, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d + + pos_delta_min = rigid_global_info.qpos[i_q, i_b] - dofs_info.limit[I_d][0] + pos_delta_max = dofs_info.limit[I_d][1] - rigid_global_info.qpos[i_q, i_b] + pos_delta = qd.min(pos_delta_min, pos_delta_max) + + if pos_delta < 0: + n_con = n_con_counter + n_con_counter = n_con_counter + 1 + + # Replay forward intermediates (cheap, avoids stashing). + sign_pos = (pos_delta_min < pos_delta_max) * 2 - 1 + sign_f = gs.qd_float(sign_pos) + + sol_params = joints_info.sol_params[I_j] + timeconst = sol_params[0] + dampratio = sol_params[1] + dmin = sol_params[2] + dmax = sol_params[3] + width = sol_params[4] + mid = sol_params[5] + power = sol_params[6] + + imp_x = qd.abs(pos_delta) / width + imp_a_coef = 1.0 / mid ** (power - 1.0) + imp_b_coef = 1.0 / (1.0 - mid) ** (power - 1.0) + imp_a = imp_a_coef * imp_x**power + imp_b = 1.0 - imp_b_coef * (1.0 - imp_x) ** power + imp_y = imp_a if imp_x < mid else imp_b + imp_raw = dmin + imp_y * (dmax - dmin) + imp_clamped = qd.math.clamp(imp_raw, dmin, dmax) + imp = dmax if imp_x > 1.0 else imp_clamped + + b_coef = 2.0 / (dmax * timeconst) + k_coef = 1.0 / (dmax * dmax * timeconst * timeconst * dampratio * dampratio) + + invweight = dofs_info.invweight[I_d] + diag_raw = invweight * (1.0 - imp) / imp + diag = qd.max(diag_raw, EPS) + + # Upstream grads. + ga = constraint_state.dL_daref[n_con, i_b] + gD = constraint_state.dL_defc_D[n_con, i_b] + + # --- Partials of forward outputs w.r.t. intermediates --- + # aref = -b_coef * jac_qvel - k_coef * imp * pos_delta + d_aref_d_imp = -k_coef * pos_delta + d_aref_d_jac_qvel = -b_coef + d_aref_d_pos_delta_direct = -k_coef * imp + + # diag_raw = invweight*(1-imp)/imp ⇒ d(diag_raw)/d(imp) = -invweight/imp^2 + # diag = max(diag_raw, EPS); efc_D = 1/diag + # d(efc_D)/d(imp) = -1/diag^2 · d(diag)/d(imp), 0 if clamped to EPS + d_diag_d_imp = gs.qd_float(0.0) + if diag_raw > EPS: + d_diag_d_imp = -invweight / (imp * imp) + d_efc_D_d_imp = -d_diag_d_imp / (diag * diag) + + # d(imp)/d(imp_x): active only inside the smooth clamp band. + within_clamp = (imp_raw > dmin) and (imp_raw < dmax) and (imp_x <= 1.0) + d_imp_y_d_imp_x = gs.qd_float(0.0) + if imp_x < mid: + d_imp_y_d_imp_x = power * imp_a_coef * imp_x ** (power - 1.0) + else: + d_imp_y_d_imp_x = power * imp_b_coef * (1.0 - imp_x) ** (power - 1.0) + d_imp_d_imp_x = gs.qd_float(0.0) + if within_clamp: + d_imp_d_imp_x = (dmax - dmin) * d_imp_y_d_imp_x + + # d(imp_x)/d(pos_delta) = sign(pos_delta)/width; pos_delta < 0 ⇒ -1/width + d_imp_x_d_pos_delta = -1.0 / width + d_imp_d_pos_delta = d_imp_d_imp_x * d_imp_x_d_pos_delta + + # --- Combine --- + dL_d_imp = ga * d_aref_d_imp + gD * d_efc_D_d_imp + dL_d_pos_delta = ga * d_aref_d_pos_delta_direct + dL_d_imp * d_imp_d_pos_delta + dL_d_jac_qvel = ga * d_aref_d_jac_qvel + + # --- Propagate --- + rigid_global_info.qpos.grad[i_q, i_b] += sign_f * dL_d_pos_delta + dofs_state.vel.grad[i_d, i_b] += sign_f * dL_d_jac_qvel + + +@qd.kernel(fastcache=True) +def kernel_manual_add_collision_constraints_bw( + links_info: array_class.LinksInfo, + links_state: array_class.LinksState, + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + collider_state: array_class.ColliderState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), +): + """Manual reverse of `add_collision_constraints`. See the section header + above for the shared `n_con` layout and upstream-grad conventions. + + Produces the gradients w.r.t. the collision constraint's differentiable inputs: + collider_state.contact_data.{pos, normal, penetration}.grad (-> collider.backward) + dofs_state.{cdof_ang, cdof_vel, vel}.grad + links_state.root_COM.grad + (cdof / root_COM / vel grads feed the COM / forward-dynamics reverse chain; + contact_data grads feed `collider.backward`.) + + Forward recap (per contact `i_col`, per friction-pyramid row `i` in 0..3): + d1, d2 = qd_orthogonals(normal); d = s_i * (d1 if i<2 else d2), s_i = 2*(i%2)-1 + n = d * friction - normal + jac[n_con, i_d] = sum_chain (sign * vel_motion(i_d)) . n + vel_motion = cdof_vel - t_pos x cdof_ang, t_pos = contact_pos - root_COM[link] + jac_qvel = sum_chain jac[n_con, i_d] * dofs_vel[i_d] + imp, aref = imp_aref(sol_params, -penetration, jac_qvel, -penetration) + diag = (invweight + friction^2 invweight) * 2 friction^2 (1-imp)/imp ; efc_D = 1/diag + """ + EPS = rigid_global_info.EPS[None] + _B = dofs_state.ctrl_mode.shape[1] + n_dofs = dofs_state.ctrl_mode.shape[0] + max_contact_pairs = collider_state.contact_data.link_a.shape[0] + + qd.loop_config( + name="kernel_manual_add_collision_constraints_bw", + serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL), + ) + for flat_idx in range(max_contact_pairs * _B): + i_b = flat_idx % _B + i_col = flat_idx // _B + if i_col < collider_state.n_contacts[i_b]: + link_a = collider_state.contact_data.link_a[i_col, i_b] + link_b = collider_state.contact_data.link_b[i_col, i_b] + contact_pos = collider_state.contact_data.pos[i_col, i_b] + normal = collider_state.contact_data.normal[i_col, i_b] + friction = collider_state.contact_data.friction[i_col, i_b] + sol_params = collider_state.contact_data.sol_params[i_col, i_b] + penetration = collider_state.contact_data.penetration[i_col, i_b] + + link_a_maybe_batch = [link_a, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else link_a + invweight = links_info.invweight[link_a_maybe_batch][0] + if link_b > -1: + link_b_maybe_batch = [link_b, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else link_b + invweight = invweight + links_info.invweight[link_b_maybe_batch][0] + + # --- forward intermediates of qd_orthogonals(normal) --- + # b_raw branches on |normal[1]| < 0.5; b = normalize(b_raw) + # d1 = b x normal, d2 = b + n0, n1, n2 = normal[0], normal[1], normal[2] + branch_a = qd.abs(n1) < 0.5 + b_raw = gs.qd_vec3(0.0, 0.0, 0.0) + if branch_a: + b_raw = gs.qd_vec3(-n0 * n1, 1.0 - n1 * n1, -n2 * n1) + else: + b_raw = gs.qd_vec3(-n0 * n2, -n1 * n2, 1.0 - n2 * n2) + b_raw_norm = b_raw.norm() + b = b_raw / b_raw_norm + d1 = b.cross(normal) + d2 = b + + sol_timeconst = sol_params[0] + sol_dampratio = sol_params[1] + sol_dmin = sol_params[2] + sol_dmax = sol_params[3] + sol_width = sol_params[4] + sol_mid = sol_params[5] + sol_power = sol_params[6] + + neg_pen = -penetration + imp_x = qd.abs(neg_pen) / sol_width + # d(imp_x)/d(penetration) = -sign(neg_pen)/width + sign_neg = gs.qd_float(1.0) if neg_pen >= 0 else gs.qd_float(-1.0) + d_imp_x_d_pen = -sign_neg / sol_width + + imp_a_coef = 1.0 / sol_mid ** (sol_power - 1.0) + imp_b_coef = 1.0 / (1.0 - sol_mid) ** (sol_power - 1.0) + imp_a = imp_a_coef * imp_x**sol_power + imp_b = 1.0 - imp_b_coef * (1.0 - imp_x) ** sol_power + imp_y = imp_a if imp_x < sol_mid else imp_b + imp_raw = sol_dmin + imp_y * (sol_dmax - sol_dmin) + imp_clamped = qd.math.clamp(imp_raw, sol_dmin, sol_dmax) + imp = sol_dmax if imp_x > 1.0 else imp_clamped + + b_coef = 2.0 / (sol_dmax * sol_timeconst) + # k_coef matches gu.imp_aref's k = 1/(dmax^2 timeconst^2 dampratio^2) + k_coef = 1.0 / (sol_dmax * sol_dmax * sol_timeconst * sol_timeconst * sol_dampratio * sol_dampratio) + + # diag = C0 * (1-imp)/imp, C0 = 2 friction^2 invweight (1 + friction^2) + C0 = (invweight + friction * friction * invweight) * 2.0 * friction * friction + diag_raw = C0 * (1.0 - imp) / imp + diag = qd.max(diag_raw, EPS) + + within_clamp = (imp_raw > sol_dmin) and (imp_raw < sol_dmax) and (imp_x <= 1.0) + d_imp_y_d_imp_x = gs.qd_float(0.0) + if imp_x < sol_mid: + d_imp_y_d_imp_x = sol_power * imp_a_coef * imp_x ** (sol_power - 1.0) + else: + d_imp_y_d_imp_x = sol_power * imp_b_coef * (1.0 - imp_x) ** (sol_power - 1.0) + d_imp_d_imp_x = gs.qd_float(0.0) + if within_clamp: + d_imp_d_imp_x = (sol_dmax - sol_dmin) * d_imp_y_d_imp_x + + d_diag_d_imp = gs.qd_float(0.0) + if diag_raw > EPS: + d_diag_d_imp = -C0 / (imp * imp) + d_efc_D_d_imp = -d_diag_d_imp / (diag * diag) + + # Accumulators for this contact's differentiable inputs. + g_pos = gs.qd_vec3(0.0, 0.0, 0.0) + g_normal = gs.qd_vec3(0.0, 0.0, 0.0) + g_pen = gs.qd_float(0.0) + g_d1 = gs.qd_vec3(0.0, 0.0, 0.0) + g_d2 = gs.qd_vec3(0.0, 0.0, 0.0) + + for i in range(4): + s_i = gs.qd_float(2 * (i % 2) - 1) + d = s_i * d1 if i < 2 else s_i * d2 + n = d * friction - normal + n_con = i_col * 4 + i + + ga = constraint_state.dL_daref[n_con, i_b] + gD = constraint_state.dL_defc_D[n_con, i_b] + + # aref = -b_coef*jac_qvel + k_coef*imp*penetration (pos arg = -penetration) + d_aref_d_imp = k_coef * penetration + d_aref_d_pen_direct = k_coef * imp + d_aref_d_jac_qvel = -b_coef + + dL_d_imp = ga * d_aref_d_imp + gD * d_efc_D_d_imp + dL_d_pen = ga * d_aref_d_pen_direct + dL_d_imp * d_imp_d_imp_x * d_imp_x_d_pen + g_pen += dL_d_pen + dL_d_jac_qvel = ga * d_aref_d_jac_qvel + + # Reverse jac[n_con, i_d] over the kinematic chain. + dL_dn = gs.qd_vec3(0.0, 0.0, 0.0) + for i_ab in range(2): + sign = gs.qd_float(-1.0) + link = link_a + if i_ab == 1: + sign = gs.qd_float(1.0) + link = link_b + while link > -1: + link_mb = [link, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else link + for i_d_ in range(links_info.n_dofs[link_mb]): + i_d = links_info.dof_end[link_mb] - 1 - i_d_ + + cdof_ang = dofs_state.cdof_ang[i_d, i_b] + cdof_vel = dofs_state.cdof_vel[i_d, i_b] + t_pos = contact_pos - links_state.root_COM[link, i_b] + vel_motion = cdof_vel - t_pos.cross(cdof_ang) + + jac_stored = constraint_state.jac[n_con, i_d, i_b] + g_jac = constraint_state.dL_djac[n_con, i_d, i_b] + dL_d_jac_qvel * dofs_state.vel[i_d, i_b] + dofs_state.vel.grad[i_d, i_b] += dL_d_jac_qvel * jac_stored + + # jac_contrib = (sign * vel_motion) . n + dL_dn += g_jac * sign * vel_motion + g_vm = g_jac * sign * n # dL/d(vel_motion) + + # vel_motion = cdof_vel - t_pos x cdof_ang + dofs_state.cdof_vel.grad[i_d, i_b] += g_vm + dofs_state.cdof_ang.grad[i_d, i_b] += t_pos.cross(g_vm) + dt = -(cdof_ang.cross(g_vm)) # dL/d(t_pos) + g_pos += dt + links_state.root_COM.grad[link, i_b] += -dt + + link = links_info.parent_idx[link_mb] + + # n = d*friction - normal + g_normal += -dL_dn + g_d = dL_dn * friction + if i < 2: + g_d1 += s_i * g_d + else: + g_d2 += s_i * g_d + + # Reverse qd_orthogonals: d1 = b x normal, d2 = b, b = normalize(b_raw(normal)). + dL_db = g_d2 + normal.cross(g_d1) + g_normal += g_d1.cross(b) + # b = b_raw / |b_raw| + dL_db_raw = (dL_db - dL_db.dot(b) * b) / b_raw_norm + # b_raw(normal) branch Jacobian + if branch_a: + # b_raw = (-n0 n1, 1 - n1^2, -n2 n1) + g_normal[0] += dL_db_raw[0] * (-n1) + g_normal[1] += dL_db_raw[0] * (-n0) + dL_db_raw[1] * (-2.0 * n1) + dL_db_raw[2] * (-n2) + g_normal[2] += dL_db_raw[2] * (-n1) + else: + # b_raw = (-n0 n2, -n1 n2, 1 - n2^2) + g_normal[0] += dL_db_raw[0] * (-n2) + g_normal[1] += dL_db_raw[1] * (-n2) + g_normal[2] += dL_db_raw[0] * (-n0) + dL_db_raw[1] * (-n1) + dL_db_raw[2] * (-2.0 * n2) + + for j in qd.static(range(3)): + collider_state.contact_data.pos.grad[i_col, i_b][j] = g_pos[j] + collider_state.contact_data.normal.grad[i_col, i_b][j] = g_normal[j] + collider_state.contact_data.penetration.grad[i_col, i_b] = g_pen diff --git a/genesis/engine/solvers/rigid/constraint/noslip.py b/genesis/engine/solvers/rigid/constraint/noslip.py index cf4c465a93..7df5fa94b0 100644 --- a/genesis/engine/solvers/rigid/constraint/noslip.py +++ b/genesis/engine/solvers/rigid/constraint/noslip.py @@ -44,11 +44,9 @@ def kernel_build_efc_AR_b( i_b, constraint_state.Mgrad, constraint_state.Mgrad, - None, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=False, ) # TODO: For consistency with other usages, migrate to either the lower or upper variant @@ -331,11 +329,9 @@ def kernel_dual_finish( i_b=i_b, vec=constraint_state.qfrc_constraint, out=constraint_state.qacc, - out_bw=None, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=False, ) for i_d in range(n_dofs): @@ -426,11 +422,9 @@ def compute_A_diag( i_b, constraint_state.Mgrad, constraint_state.Mgrad, - None, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=False, ) # Ai = Ji * tmp diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py index 0283f7c206..c1397b7653 100644 --- a/genesis/engine/solvers/rigid/constraint/solver.py +++ b/genesis/engine/solvers/rigid/constraint/solver.py @@ -442,12 +442,13 @@ def delete_weld_constraint(self, link1_idx, link2_idx, envs_idx=None): self._solver._static_rigid_sim_config, ) - def backward(self, dL_dqacc): + def backward(self): if not self._solver._requires_grad: gs.raise_exception("Please set `requires_grad` to True in SimOptions to enable differentiable mode.") - # Copy upstream gradients - self.constraint_state.dL_dqacc.from_numpy(dL_dqacc) + # Upstream gradient `dL_dqacc` is expected to be pre-populated in + # `constraint_state.dL_dqacc` by the caller (see + # `kernel_load_dL_dqacc_from_acc_grad`). # 1. We first need to find a solution to A^T * u = g system. backward_constraint_solver.kernel_solve_adjoint_u( @@ -3522,11 +3523,9 @@ def func_update_gradient_batch( i_b, constraint_state.grad, constraint_state.Mgrad, - None, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=False, ) if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton): @@ -3569,11 +3568,9 @@ def func_update_gradient_tiled( i_b, constraint_state.grad, constraint_state.Mgrad, - None, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, - is_backward=False, ) if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton): diff --git a/genesis/engine/solvers/rigid/constraint/solver_island.py b/genesis/engine/solvers/rigid/constraint/solver_island.py index 3a300cda20..b18a4ff833 100644 --- a/genesis/engine/solvers/rigid/constraint/solver_island.py +++ b/genesis/engine/solvers/rigid/constraint/solver_island.py @@ -987,11 +987,9 @@ def _func_update_gradient( i_b, self.grad, self.Mgrad, - None, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=self._solver._static_rigid_sim_config, - is_backward=False, ) for i_e in range(self._solver.n_entities): self._solver.mass_mat_mask[i_e, i_b] = True diff --git a/genesis/engine/solvers/rigid/rigid_solver.py b/genesis/engine/solvers/rigid/rigid_solver.py index cf61c94ef2..85f700de74 100644 --- a/genesis/engine/solvers/rigid/rigid_solver.py +++ b/genesis/engine/solvers/rigid/rigid_solver.py @@ -29,6 +29,12 @@ from ..kinematic_solver import KinematicSolver from .collider import Collider from .constraint import ConstraintSolver, ConstraintSolverIsland +from .constraint.backward import ( + kernel_manual_add_joint_limit_constraints_bw, + kernel_manual_add_collision_constraints_bw, + kernel_load_dL_dqacc_from_acc_grad, + kernel_accumulate_constraint_solver_grads, +) from .abd.misc import ( func_add_safe_backward, func_apply_coupling_force, @@ -91,7 +97,10 @@ kernel_update_all_verts, kernel_update_geom_aabbs, kernel_update_vgeoms, + kernel_COM_links_replay, kernel_update_cartesian_space, + kernel_forward_kinematics_replay, + kernel_update_geoms_replay, ) from .abd.forward_dynamics import ( func_actuation, @@ -169,6 +178,12 @@ kernel_prepare_backward_substep, kernel_begin_backward_substep, kernel_copy_acc, + kernel_copy_next_to_curr_no_check, +) +from .abd.manual_bw import ( + kernel_manual_compute_qacc_bw, + kernel_manual_forward_kinematics_bw, + kernel_manual_forward_velocity_bw, ) if TYPE_CHECKING: @@ -1027,7 +1042,6 @@ def substep(self, f): self.dofs_state, self._rigid_global_info, self._static_rigid_sim_config, - self._is_backward, ) else: self._func_constraint_force() @@ -1089,6 +1103,12 @@ def check_errno(self): gs.raise_exception("Invalid accelerations causing 'nan'. Please decrease Rigid simulation timestep.") if errno & array_class.ErrorCode.OVERFLOW_HIBERNATION_ISLANDS: gs.raise_exception("Contact island buffer overflow. Please increase RigidOptions 'max_collision_pairs'.") + if errno & array_class.ErrorCode.MANUAL_BW_UNIMPLEMENTED: + gs.raise_exception( + "Encountered a configuration (e.g. hibernation) that the manual backward kernels " + "do not support. Extend the corresponding `kernel_manual_*_bw` in " + "`genesis/engine/solvers/rigid/abd/manual_bw.py`." + ) def _kernel_detect_collision(self): self.collider.clear() @@ -1302,6 +1322,105 @@ def reset_grad(self): qd_zero_grad(self.geoms_state_adjoint_cache) qd_zero_grad(self._rigid_adjoint_cache) + def _update_cartesian_grad(self, envs_idx): + """Forward-replay the post-integrate cartesian-space update + (FK -> COM -> geom poses -> velocity) under `is_backward=True`, then + reverse it: velocity and FK are reversed manually (`kernel_manual_*_bw`), + while COM and the link->geom transform are reversed by Quadrants autograd + (`.grad`). Shared by both the post-FK reverse and the first-substep + initial-state reverse in `substep_pre_coupling_grad`. + """ + # Forward replay in dependency order (FK -> COM -> geoms -> velocity). + kernel_forward_kinematics_replay( + envs_idx=envs_idx, + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=True, + ) + kernel_COM_links_replay( + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=True, + ) + kernel_update_geoms_replay( + entities_info=self.entities_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + links_state=self.links_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=True, + ) + kernel_forward_velocity( + envs_idx=envs_idx, + links_state=self.links_state, + links_info=self.links_info, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=True, + ) + + # Reverse in opposite order (vel.bw -> COM.bw -> geoms.bw -> FK.bw). + kernel_manual_forward_velocity_bw( + links_state=self.links_state, + links_info=self.links_info, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + errno=self._errno, + ) + kernel_COM_links_replay.grad( + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=True, + ) + kernel_update_geoms_replay.grad( + entities_info=self.entities_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + links_state=self.links_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=True, + ) + kernel_manual_forward_kinematics_bw( + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_info=self.dofs_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + errno=self._errno, + ) + def substep_pre_coupling_grad(self, f): # Change to backward mode self._is_backward = True @@ -1328,36 +1447,17 @@ def substep_pre_coupling_grad(self, f): static_rigid_sim_config=self._static_rigid_sim_config, ) self.substep(f) - # =================== Backward substep ====================== envs_idx = self._scene._sanitize_envs_idx(None) if not self._enable_mujoco_compatibility: - kernel_forward_velocity.grad( - envs_idx=envs_idx, - links_state=self.links_state, - links_info=self.links_info, - joints_info=self.joints_info, + # The FK backward below builds its Jacobian at the post-integrate qpos/vel, so + # copy the integrator's `_next` outputs into the current slots first. + kernel_copy_next_to_curr_no_check( dofs_state=self.dofs_state, - entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - is_backward=True, - ) - kernel_update_cartesian_space.grad( - links_state=self.links_state, - links_info=self.links_info, - joints_state=self.joints_state, - joints_info=self.joints_info, - dofs_state=self.dofs_state, - dofs_info=self.dofs_info, - geoms_state=self.geoms_state, - geoms_info=self.geoms_info, - entities_info=self.entities_info, - rigid_global_info=self._rigid_global_info, - static_rigid_sim_config=self._static_rigid_sim_config, - force_update_fixed_geoms=False, - is_backward=True, ) + self._update_cartesian_grad(envs_idx) is_grad_valid = kernel_begin_backward_substep( f=f, @@ -1382,40 +1482,104 @@ def substep_pre_coupling_grad(self, f): gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}") kernel_step_2.grad( - dofs_state=self.dofs_state, - dofs_info=self.dofs_info, - links_info=self.links_info, - links_state=self.links_state, - joints_info=self.joints_info, - joints_state=self.joints_state, - entities_state=self.entities_state, - entities_info=self.entities_info, - geoms_info=self.geoms_info, - geoms_state=self.geoms_state, - collider_state=self.collider._collider_state, - rigid_global_info=self._rigid_global_info, - static_rigid_sim_config=self._static_rigid_sim_config, - contact_island_state=self.constraint_solver.contact_island.contact_island_state, - is_backward=True, - errno=self._errno, + self.dofs_state, + self.dofs_info, + self.links_info, + self.links_state, + self.joints_info, + self.joints_state, + self.entities_state, + self.entities_info, + self.geoms_info, + self.geoms_state, + self.collider._collider_state, + self._rigid_global_info, + self._static_rigid_sim_config, + self.constraint_solver.contact_island.contact_island_state, + self._is_backward, + self._errno, ) - # We cannot use [kernel_forward_dynamics.grad] because we read [dofs_state.acc] and overwrite it in the kernel, - # which is prohibited (https://docs.taichi-lang.org/docs/differentiable_programming#global-data-access-rules). - # In [kernel_forward_dynamics], we read [acc] in [func_update_acc] and overwrite it in [kernel_compute_qacc]. - # As [kenrel_compute_qacc] is called at the end of [kernel_forward_dynamics], we first backpropagate through - # [kernel_compute_qacc] and then restore the original [acc] from the adjoint cache. This copy operation - # cannot be merged with [kernel_compute_qacc.grad] because .grad function itself is a standalone kernel. - # We could possibly merge this small kernel later if (1) .grad function is regarded as a function instead of a - # kernel, (2) we add another variable to store the new [acc] from [kernel_compute_qacc] and thus can avoid - # the data access violation. However, both of these require major changes. - kernel_compute_qacc.grad( - dofs_state=self.dofs_state, - entities_info=self.entities_info, - rigid_global_info=self._rigid_global_info, - static_rigid_sim_config=self._static_rigid_sim_config, - is_backward=True, - ) + # Two backward paths for `force -> acc`: + # (A) Unconstrained: `kernel_manual_compute_qacc_bw` does the IFT through M + # (writes force.grad + mass_mat.grad). + # (B) Constrained (active joint limits / collision): the constraint solve + # overwrites acc, so the IFT through M alone is wrong. Instead drive + # `constraint_solver.backward` (adjoint KKT -> dL_dM/djac/daref/defc_D/ + # dforce) + the per-constraint manual reverses below. + _has_collision = (not self._disable_constraint) and self._enable_collision + _has_joint_limit = (not self._disable_constraint) and self._options.enable_joint_limit + _constrained_bw = _has_collision or _has_joint_limit + if _constrained_bw: + # Reject equality / frictionloss constraints: forward orders rows as + # equality -> frictionloss -> collision -> joint-limit, but the + # manual reverses below only re-walk the last two groups. + # TODO: implement manual reverses for equality and frictionloss + # rows so they can participate in differentiable scenes; until then + # we reject host-side instead of producing a wrong gradient. + cs = self.constraint_solver.constraint_state + n_eq_max = int(qd_to_numpy(cs.n_constraints_equality).max()) + n_fric_max = int(qd_to_numpy(cs.n_constraints_frictionloss).max()) + if n_eq_max > 0 or n_fric_max > 0: + gs.raise_exception( + "Differentiable rigid backward does not support equality or frictionloss " + f"constraints (found n_constraints_equality={n_eq_max}, " + f"n_constraints_frictionloss={n_fric_max}). Disable them in a differentiable scene." + ) + + # Backward pass for the constraint solver + kernel_load_dL_dqacc_from_acc_grad( + dofs_state=self.dofs_state, + constraint_state=self.constraint_solver.constraint_state, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + self.constraint_solver.backward() + kernel_accumulate_constraint_solver_grads( + dofs_state=self.dofs_state, + rigid_global_info=self._rigid_global_info, + constraint_state=self.constraint_solver.constraint_state, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + + if _has_collision: + # Manual reverse of the collision constraint rows -> contact-data + # grads, then differentiate the narrow-phase (diff GJK) through + # collider.backward into geom pose grads. + collider_state = self.collider._collider_state + collider_state.contact_data.pos.grad.fill(0.0) + collider_state.contact_data.normal.grad.fill(0.0) + collider_state.contact_data.penetration.grad.fill(0.0) + kernel_manual_add_collision_constraints_bw( + links_info=self.links_info, + links_state=self.links_state, + dofs_state=self.dofs_state, + constraint_state=self.constraint_solver.constraint_state, + collider_state=collider_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + self.collider.backward_narrowphase() + + if _has_joint_limit: + kernel_manual_add_joint_limit_constraints_bw( + links_info=self.links_info, + joints_info=self.joints_info, + dofs_info=self.dofs_info, + dofs_state=self.dofs_state, + rigid_global_info=self._rigid_global_info, + constraint_state=self.constraint_solver.constraint_state, + collider_state=self.collider._collider_state, + static_rigid_sim_config=self._static_rigid_sim_config, + enable_collision=_has_collision, + ) + else: + # Manual backward for `func_compute_qacc` via the Implicit Function Theorem. + kernel_manual_compute_qacc_bw( + dofs_state=self.dofs_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) kernel_copy_acc( f=f, dofs_state=self.dofs_state, @@ -1440,32 +1604,7 @@ def substep_pre_coupling_grad(self, f): # If it was the very first substep, we need to backpropagate through the initial update of the cartesian space if self._enable_mujoco_compatibility or self._sim.cur_substep_global == 0: - kernel_forward_velocity.grad( - envs_idx=envs_idx, - links_state=self.links_state, - links_info=self.links_info, - joints_info=self.joints_info, - dofs_state=self.dofs_state, - entities_info=self.entities_info, - rigid_global_info=self._rigid_global_info, - static_rigid_sim_config=self._static_rigid_sim_config, - is_backward=True, - ) - kernel_update_cartesian_space.grad( - links_state=self.links_state, - links_info=self.links_info, - joints_state=self.joints_state, - joints_info=self.joints_info, - dofs_state=self.dofs_state, - dofs_info=self.dofs_info, - geoms_state=self.geoms_state, - geoms_info=self.geoms_info, - entities_info=self.entities_info, - rigid_global_info=self._rigid_global_info, - static_rigid_sim_config=self._static_rigid_sim_config, - force_update_fixed_geoms=False, - is_backward=True, - ) + self._update_cartesian_grad(envs_idx) # Change back to forward mode self._is_backward = False @@ -1481,7 +1620,6 @@ def substep_post_coupling(self, f): dofs_state=self.dofs_state, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, - is_backward=self._is_backward, ) kernel_step_2( dofs_state=self.dofs_state, @@ -1688,9 +1826,11 @@ def save_ckpt(self, ckpt_name): if ckpt_name not in self._ckpt: self._ckpt[ckpt_name] = dict() - self._ckpt[ckpt_name]["qpos"] = qd_to_numpy(self._rigid_adjoint_cache.qpos) - self._ckpt[ckpt_name]["dofs_vel"] = qd_to_numpy(self._rigid_adjoint_cache.dofs_vel) - self._ckpt[ckpt_name]["dofs_acc"] = qd_to_numpy(self._rigid_adjoint_cache.dofs_acc) + # `copy=True` required: with the zerocopy backend `qd_to_numpy` returns a + # view, so later substeps would overwrite this ckpt's buffer in place. + self._ckpt[ckpt_name]["qpos"] = qd_to_numpy(self._rigid_adjoint_cache.qpos, copy=True) + self._ckpt[ckpt_name]["dofs_vel"] = qd_to_numpy(self._rigid_adjoint_cache.dofs_vel, copy=True) + self._ckpt[ckpt_name]["dofs_acc"] = qd_to_numpy(self._rigid_adjoint_cache.dofs_acc, copy=True) for entity in self._entities: entity.save_ckpt(ckpt_name) diff --git a/genesis/utils/array_class.py b/genesis/utils/array_class.py index 5baad880da..564f073ec5 100644 --- a/genesis/utils/array_class.py +++ b/genesis/utils/array_class.py @@ -86,6 +86,7 @@ class ErrorCode(IntEnum): OVERFLOW_HIBERNATION_ISLANDS = 0b00000000000000000000000000000100 INVALID_FORCE_NAN = 0b00000000000000000000000000001000 INVALID_ACC_NAN = 0b00000000000000000000000000010000 + MANUAL_BW_UNIMPLEMENTED = 0b00000000000000000000000000100000 # =========================================== RigidGlobalInfo =========================================== @@ -108,7 +109,6 @@ class RigidGlobalInfo: geoms_init_AABB: qd.Tensor mass_mat: qd.Tensor mass_mat_L: qd.Tensor - mass_mat_L_bw: qd.Tensor mass_mat_D_inv: qd.Tensor mass_mat_mask: qd.Tensor meaninertia: qd.Tensor @@ -138,11 +138,6 @@ def get_rigid_global_info(solver, kinematic_only): f"Mass matrix shape (n_dofs={solver.n_dofs_}, n_dofs={solver.n_dofs_}, n_envs={_B}) is too large." ) requires_grad = solver._requires_grad - mass_mat_shape_bw = maybe_shape((2, *mass_mat_shape), requires_grad) - if math.prod(mass_mat_shape_bw) > np.iinfo(np.int32).max: - gs.raise_exception( - f"Mass matrix buffer shape (2, n_dofs={solver.n_dofs_}, n_dofs={solver.n_dofs_}, n_envs={_B}) is too large." - ) # Flip mass_mat from canonical (n_dofs(i_d1), n_dofs(i_d2), _B) -> physical (_B, n_dofs(i_d2), n_dofs(i_d1)) via # layout=(2, 1, 0): i_d1 becomes innermost / stride-1, which coalesces consumer kernels whose lanes stride i_d1 @@ -175,7 +170,6 @@ def get_rigid_global_info(solver, kinematic_only): geoms_init_AABB=V_VEC(3, dtype=gs.qd_float, shape=()), mass_mat=V(dtype=gs.qd_float, shape=()), mass_mat_L=V(dtype=gs.qd_float, shape=()), - mass_mat_L_bw=V(dtype=gs.qd_float, shape=()), mass_mat_D_inv=V(dtype=gs.qd_float, shape=()), mass_mat_mask=V(dtype=gs.qd_bool, shape=()), mass_parent_mask=V(dtype=gs.qd_float, shape=()), @@ -210,7 +204,6 @@ def get_rigid_global_info(solver, kinematic_only): geoms_init_AABB=V_VEC(3, dtype=gs.qd_float, shape=(solver.n_geoms_, 8)), mass_mat=V(dtype=gs.qd_float, shape=mass_mat_shape, layout=mass_mat_layout, needs_grad=requires_grad), mass_mat_L=V(dtype=gs.qd_float, shape=mass_mat_shape, needs_grad=requires_grad), - mass_mat_L_bw=V(dtype=gs.qd_float, shape=mass_mat_shape_bw, needs_grad=requires_grad), mass_mat_D_inv=V(dtype=gs.qd_float, shape=(solver.n_dofs_, _B), needs_grad=requires_grad), mass_mat_mask=V(dtype=gs.qd_bool, shape=(solver.n_entities_, _B)), mass_parent_mask=V(dtype=gs.qd_float, shape=(solver.n_dofs_, solver.n_dofs_)), @@ -496,6 +489,9 @@ class DiffContactInput: # Local positions of the 1 vertex from the two geometries that define the support point for the face above w_local_pos1: qd.Tensor w_local_pos2: qd.Tensor + # Plane-convex contacts only: the convex support "core" (box vertex / sphere + # center / capsule nearest endpoint) in the convex geom's local frame. + core_local: qd.Tensor # Reference id of the contact point, which is needed for the backward pass ref_id: qd.Tensor # Flag whether the contact data can be computed in numerically stable way in both the forward and backward passes @@ -518,6 +514,7 @@ def get_diff_contact_input(_B, max_contacts_per_pair, is_active, requires_grad=F local_pos2_c=V_VEC(3, dtype=gs.qd_float, shape=shape), w_local_pos1=V_VEC(3, dtype=gs.qd_float, shape=shape), w_local_pos2=V_VEC(3, dtype=gs.qd_float, shape=shape), + core_local=V_VEC(3, dtype=gs.qd_float, shape=shape), ref_id=V(dtype=gs.qd_int, shape=shape), valid=V(dtype=gs.qd_int, shape=shape), ref_penetration=V(dtype=gs.qd_float, shape=shape, needs_grad=True), diff --git a/genesis/utils/misc.py b/genesis/utils/misc.py index c4cf160bda..034da71ac1 100644 --- a/genesis/utils/misc.py +++ b/genesis/utils/misc.py @@ -561,7 +561,7 @@ def qd_to_torch( except AttributeError: try: tc = value.to_torch(copy=False) - except (ValueError, RuntimeError): + except (ValueError, RuntimeError, TypeError): if copy is False: raise tensor = _maybe_transpose(value.to_torch(), value, transpose) diff --git a/tests/test_grad.py b/tests/test_grad.py deleted file mode 100644 index d6276fd6e8..0000000000 --- a/tests/test_grad.py +++ /dev/null @@ -1,546 +0,0 @@ -import numpy as np -import pytest -import torch - -import genesis as gs -from genesis.utils.geom import R_to_quat -from genesis.utils.misc import qd_to_torch, qd_to_numpy, tensor_to_array -from genesis.utils import set_random_seed - -from .utils import assert_allclose - - -@pytest.mark.required -@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) -def test_differentiable_push(show_viewer): - HORIZON = 10 - - scene = gs.Scene( - sim_options=gs.options.SimOptions( - dt=2e-3, - substeps=10, - requires_grad=True, - ), - mpm_options=gs.options.MPMOptions( - lower_bound=(0.0, -1.0, 0.0), - upper_bound=(1.0, 1.0, 0.55), - ), - viewer_options=gs.options.ViewerOptions( - camera_pos=(2.5, -0.15, 2.42), - camera_lookat=(0.5, 0.5, 0.1), - ), - show_viewer=show_viewer, - ) - - plane = scene.add_entity( - gs.morphs.URDF( - file="urdf/plane/plane.urdf", - fixed=True, - ) - ) - stick = scene.add_entity( - morph=gs.morphs.Mesh( - file="meshes/stirrer.obj", - scale=0.6, - pos=(0.5, 0.5, 0.05), - euler=(90.0, 0.0, 0.0), - ), - material=gs.materials.Tool( - friction=8.0, - ), - ) - obj = scene.add_entity( - morph=gs.morphs.Box( - lower=(0.2, 0.1, 0.05), - upper=(0.4, 0.3, 0.15), - ), - material=gs.materials.MPM.Elastic( - rho=500, - ), - ) - scene.build(n_envs=2) - - init_pos = gs.tensor([[0.3, 0.1, 0.28], [0.3, 0.1, 0.5]], requires_grad=True) - stick.set_position(init_pos) - pos_obj_init = gs.tensor([0.3, 0.3, 0.1], requires_grad=True) - obj.set_position(pos_obj_init) - v_obj_init = gs.tensor([0.0, -1.0, 0.0], requires_grad=True) - obj.set_velocity(v_obj_init) - goal = gs.tensor([0.5, 0.8, 0.05]) - - loss = 0.0 - v_list = [] - for i in range(HORIZON): - v_i = gs.tensor([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]], requires_grad=True) - stick.set_velocity(vel=v_i) - v_list.append(v_i) - - scene.step() - - if i == HORIZON // 2: - mpm_particles = scene.get_state().solvers_state[scene.solvers.index(scene.mpm_solver)] - loss += torch.pow(mpm_particles.pos[mpm_particles.active == 1] - goal, 2).sum() - - if i == HORIZON - 2: - state = obj.get_state() - loss += torch.pow(state.pos - goal, 2).sum() - loss.backward() - - # TODO: It would be great to compare the gradient to its analytical or numerical value. - for v_i in v_list[:-1]: - assert (v_i.grad.abs() > gs.EPS).any() - assert (v_list[-1].grad.abs() < gs.EPS).all() - - -@pytest.mark.required -@pytest.mark.precision("64") -@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) -def test_diff_contact(): - RTOL = 1e-4 - - scene = gs.Scene( - sim_options=gs.options.SimOptions( - dt=0.01, - # Turn on differentiable mode - requires_grad=True, - ), - show_viewer=False, - ) - - box_size = 0.25 - box_spacing = box_size - vec_one = np.array([1.0, 1.0, 1.0]) - box_pos_offset = (0.0, 0.0, 0.0) + 0.5 * box_size * vec_one - - box0 = scene.add_entity( - gs.morphs.Box(size=box_size * vec_one, pos=box_pos_offset), - ) - box1 = scene.add_entity( - gs.morphs.Box(size=box_size * vec_one, pos=box_pos_offset + 0.8 * box_spacing * np.array([0, 0, 1])), - ) - scene.build() - solver = scene.sim.rigid_solver - collider = solver.collider - - # Set up initial configuration - x_ang, y_ang, z_ang = 3.0, 3.0, 3.0 - box1.set_quat(R_to_quat(gs.euler_to_R([np.deg2rad(x_ang), np.deg2rad(y_ang), np.deg2rad(z_ang)]))) - - box0_init_pos = box0.get_pos().clone() - box1_init_pos = box1.get_pos().clone() - box0_init_quat = box0.get_quat().clone() - box1_init_quat = box1.get_quat().clone() - - ### Compute the initial loss and compute gradients using differentiable contact detection - # Detect contact - collider.detection() - - # Get contact outputs and their grads - contacts = collider.get_contacts(as_tensor=True, to_torch=True, keep_batch_dim=True) - normal = contacts["normal"].requires_grad_() - position = contacts["position"].requires_grad_() - penetration = contacts["penetration"].requires_grad_() - - loss = ((normal * position).sum(dim=-1) * penetration).sum() - dL_dnormal = torch.autograd.grad(loss, normal, retain_graph=True)[0] - dL_dposition = torch.autograd.grad(loss, position, retain_graph=True)[0] - dL_dpenetration = torch.autograd.grad(loss, penetration)[0] - - # Compute analytical gradients of the geoms position and quaternion - collider.backward(dL_dposition, dL_dnormal, dL_dpenetration) - dL_dpos = qd_to_torch(solver.geoms_state.pos.grad) - dL_dquat = qd_to_torch(solver.geoms_state.quat.grad) - - ### Compute directional derivatives along random directions - FD_EPS = 1e-5 - TRIALS = 100 - - def compute_dL_error(dL_dx, x_type): - dL_error_rel = 0.0 - - box0_input_pos = box0_init_pos - box1_input_pos = box1_init_pos - box0_input_quat = box0_init_quat - box1_input_quat = box1_init_quat - - for _ in range(TRIALS): - rand_dx = torch.randn_like(dL_dx) - rand_dx = torch.nn.functional.normalize(rand_dx, dim=-1) - - dL = (rand_dx * dL_dx).sum() - - lossPs = [] - for sign in (1, -1): - # Compute query point - if x_type == "pos": - box0_input_pos = box0_init_pos + sign * rand_dx[0, 0] * FD_EPS - box1_input_pos = box1_init_pos + sign * rand_dx[1, 0] * FD_EPS - else: - # FIXME: The quaternion should be normalized - box0_input_quat = box0_init_quat + sign * rand_dx[0, 0] * FD_EPS - box1_input_quat = box1_init_quat + sign * rand_dx[1, 0] * FD_EPS - - # Update box positions - box0.set_pos(box0_input_pos) - box1.set_pos(box1_input_pos) - box0.set_quat(box0_input_quat) - box1.set_quat(box1_input_quat) - - # Re-detect contact. - # We need to manually reset the contact counter as we are not running the whole sim step. - collider._collider_state.n_contacts.fill(0) - collider.detection() - contacts = collider.get_contacts(as_tensor=True, to_torch=True, keep_batch_dim=True) - normal, position, penetration = contacts["normal"], contacts["position"], contacts["penetration"] - - # Compute loss - loss = ((normal * position).sum(dim=-1) * penetration).sum() - lossPs.append(loss) - - dL_fd = (lossPs[0] - lossPs[1]) / (2 * FD_EPS) - dL_error_rel += (dL - dL_fd).abs() / max(dL.abs(), dL_fd.abs(), gs.EPS) - - dL_error_rel /= TRIALS - return dL_error_rel - - dL_dpos_error_rel = compute_dL_error(dL_dpos, "pos") - assert_allclose(dL_dpos_error_rel, 0.0, atol=RTOL) - dL_dquat_error_rel = compute_dL_error(dL_dquat, "quat") - assert_allclose(dL_dquat_error_rel, 0.0, atol=RTOL) - - -# We need to use 64-bit precision for this test because we need to use sufficiently small perturbation to get reliable -# gradient estimates through finite difference method. This small perturbation is not supported by 32-bit precision in -# stable way. -@pytest.mark.required -@pytest.mark.precision("64") -@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) -def test_diff_solver(monkeypatch): - from genesis.engine.solvers.rigid.constraint.solver import func_solve_init, func_solve_body - from genesis.engine.solvers.rigid.rigid_solver import kernel_step_1 - - RTOL = 1e-4 - - scene = gs.Scene( - sim_options=gs.options.SimOptions( - dt=0.01, - requires_grad=True, - ), - rigid_options=gs.options.RigidOptions( - # We use Newton's method because it converges faster than CG, and therefore gives better gradient estimation - # when using finite difference method - constraint_solver=gs.constraint_solver.Newton, - ), - show_viewer=False, - ) - - scene.add_entity(gs.morphs.Plane(pos=(0, 0, 0))) - scene.add_entity(gs.morphs.Box(size=(1, 1, 1), pos=(10, 10, 0.49))) - franka = scene.add_entity( - gs.morphs.MJCF(file="xml/franka_emika_panda/panda.xml"), - ) - - scene.build() - rigid_solver = scene._sim.rigid_solver - constraint_solver = rigid_solver.constraint_solver - - franka.set_qpos([-1.0124, 1.5559, 1.3662, -1.6878, -1.5799, 1.7757, 1.4602, 0.04, 0.04]) - - # Monkeypatch the constraint resolve function to avoid overwriting the necessary information for computing gradients. - def constraint_solver_resolve(): - func_solve_init( - dofs_info=rigid_solver.dofs_info, - dofs_state=rigid_solver.dofs_state, - entities_info=rigid_solver.entities_info, - constraint_state=constraint_solver.constraint_state, - rigid_global_info=rigid_solver._rigid_global_info, - static_rigid_sim_config=rigid_solver._static_rigid_sim_config, - ) - func_solve_body( - entities_info=rigid_solver.entities_info, - dofs_info=rigid_solver.dofs_info, - dofs_state=rigid_solver.dofs_state, - constraint_state=constraint_solver.constraint_state, - rigid_global_info=rigid_solver._rigid_global_info, - static_rigid_sim_config=rigid_solver._static_rigid_sim_config, - _n_iterations=constraint_solver._n_iterations, - ) - - monkeypatch.setattr(constraint_solver, "resolve", constraint_solver_resolve) - - # Step once to compute constraint solver's inputs: [mass], [jac], [aref], [efc_D], [force]. We do not call the - # entire scene.step() because it will overwrite the necessary information that we need to compute the gradients. - kernel_step_1( - links_state=rigid_solver.links_state, - links_info=rigid_solver.links_info, - joints_state=rigid_solver.joints_state, - joints_info=rigid_solver.joints_info, - dofs_state=rigid_solver.dofs_state, - dofs_info=rigid_solver.dofs_info, - geoms_state=rigid_solver.geoms_state, - geoms_info=rigid_solver.geoms_info, - entities_state=rigid_solver.entities_state, - entities_info=rigid_solver.entities_info, - rigid_global_info=rigid_solver._rigid_global_info, - static_rigid_sim_config=rigid_solver._static_rigid_sim_config, - contact_island_state=constraint_solver.contact_island.contact_island_state, - is_forward_pos_updated=True, - is_forward_vel_updated=True, - is_backward=False, - ) - constraint_solver.add_equality_constraints() - rigid_solver.collider.detection() - constraint_solver.add_inequality_constraints() - constraint_solver.resolve() - - # Loss function to compute gradients using finite difference method - def compute_loss(input_mass, input_jac, input_aref, input_efc_D, input_force): - rigid_solver._rigid_global_info.mass_mat.from_numpy(input_mass) - constraint_solver.constraint_state.jac.from_numpy(input_jac) - constraint_solver.constraint_state.aref.from_numpy(input_aref) - constraint_solver.constraint_state.efc_D.from_numpy(input_efc_D) - rigid_solver.dofs_state.force.from_numpy(input_force) - - # Recompute acc_smooth from the updated input variables - updated_acc_smooth = np.linalg.solve(input_mass[..., 0], input_force[..., 0]) - rigid_solver.dofs_state.acc_smooth.from_numpy(updated_acc_smooth[..., None]) - constraint_solver.resolve() - - output_qacc = qd_to_torch(constraint_solver.qacc) - return ((output_qacc - target_qacc) ** 2).mean() - - init_input_mass = qd_to_numpy(rigid_solver._rigid_global_info.mass_mat, copy=True) - init_input_jac = qd_to_numpy(constraint_solver.constraint_state.jac, copy=True) - init_input_aref = qd_to_numpy(constraint_solver.constraint_state.aref, copy=True) - init_input_efc_D = qd_to_numpy(constraint_solver.constraint_state.efc_D, copy=True) - init_input_force = qd_to_numpy(rigid_solver.dofs_state.force, copy=True) - - # Initial output of the constraint solver - set_random_seed(0) - init_output_qacc = qd_to_torch(constraint_solver.qacc) - target_qacc = torch.from_numpy(np.random.randn(*init_output_qacc.shape)).to(device=gs.device) - target_qacc = target_qacc * init_output_qacc.abs().mean() - - # Solve the constraint solver and get the output - output_qacc = qd_to_torch(constraint_solver.qacc, copy=True).requires_grad_(True) - - # Compute loss and gradient of the output - loss = ((output_qacc - target_qacc) ** 2).mean() - dL_dqacc = tensor_to_array(torch.autograd.grad(loss, output_qacc)[0]) - - # Compute gradients of the input variables: [mass], [jac], [aref], [efc_D], [force] - constraint_solver.backward(dL_dqacc) - - # Fetch gradients of the input variables - dL_dM = qd_to_numpy(constraint_solver.constraint_state.dL_dM) - dL_djac = qd_to_numpy(constraint_solver.constraint_state.dL_djac) - dL_daref = qd_to_numpy(constraint_solver.constraint_state.dL_daref) - dL_defc_D = qd_to_numpy(constraint_solver.constraint_state.dL_defc_D) - dL_dforce = qd_to_numpy(constraint_solver.constraint_state.dL_dforce) - - ### Compute directional derivatives along random directions - FD_EPS = 1e-3 - TRIALS = 200 - - for dL_dx, x_type in ( - (dL_dforce, "force"), - (dL_daref, "aref"), - (dL_defc_D, "efc_D"), - (dL_djac, "jac"), - (dL_dM, "mass"), - ): - dL_error = 0.0 - for _ in range(TRIALS): - rand_dx = np.random.randn(*dL_dx.shape) - rand_dx = rand_dx / max( - np.linalg.norm(rand_dx, axis=0 if x_type in ("force", "aref", "efc_D") else (0, 1)), gs.EPS - ) - if x_type == "mass": - # Make rand_dx symmetric - rand_dx = (rand_dx + np.moveaxis(rand_dx, 0, 1)) * 0.5 - - dL = (rand_dx * dL_dx).sum() - - input_force = init_input_force - input_aref = init_input_aref - input_efc_D = init_input_efc_D - input_jac = init_input_jac - input_mass = init_input_mass - - # 1 * eps - if x_type == "force": - input_force = init_input_force + rand_dx * FD_EPS - elif x_type == "aref": - input_aref = init_input_aref + rand_dx * FD_EPS - elif x_type == "efc_D": - input_efc_D = init_input_efc_D + rand_dx * FD_EPS - elif x_type == "jac": - input_jac = init_input_jac + rand_dx * FD_EPS - elif x_type == "mass": - input_mass = init_input_mass + rand_dx * FD_EPS - lossP1 = compute_loss(input_mass, input_jac, input_aref, input_efc_D, input_force) - - # -1 * eps - if x_type == "force": - input_force = init_input_force - rand_dx * FD_EPS - elif x_type == "aref": - input_aref = init_input_aref - rand_dx * FD_EPS - elif x_type == "efc_D": - input_efc_D = init_input_efc_D - rand_dx * FD_EPS - elif x_type == "jac": - input_jac = init_input_jac - rand_dx * FD_EPS - elif x_type == "mass": - input_mass = init_input_mass - rand_dx * FD_EPS - - lossP2 = compute_loss(input_mass, input_jac, input_aref, input_efc_D, input_force) - dL_fd = (lossP1 - lossP2) / (2 * FD_EPS) - - dL_error += (dL - dL_fd).abs() / max(abs(dL), abs(dL_fd), gs.EPS) - - dL_error /= TRIALS - assert_allclose(dL_error, 0.0, atol=RTOL) - - -@pytest.mark.slow # ~250s -@pytest.mark.required -@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) -def test_differentiable_rigid(show_viewer): - dt = 1e-2 - horizon = 100 - substeps = 1 - goal_pos = gs.tensor([0.7, 1.0, 0.05]) - goal_quat = gs.tensor([0.3, 0.2, 0.1, 0.9]) - goal_quat = goal_quat / torch.norm(goal_quat, dim=-1, keepdim=True) - - scene = gs.Scene( - sim_options=gs.options.SimOptions( - dt=dt, - substeps=substeps, - requires_grad=True, - gravity=(0, 0, -1), - ), - rigid_options=gs.options.RigidOptions( - enable_collision=False, - enable_self_collision=False, - enable_joint_limit=False, - disable_constraint=True, - use_contact_island=False, - use_hibernation=False, - ), - viewer_options=gs.options.ViewerOptions( - camera_pos=(2.5, -0.15, 2.42), - camera_lookat=(0.5, 0.5, 0.1), - ), - show_viewer=show_viewer, - ) - - box = scene.add_entity( - gs.morphs.Box( - pos=(0, 0, 0), - size=(0.1, 0.1, 0.2), - ), - surface=gs.surfaces.Default( - color=(0.9, 0.0, 0.0, 1.0), - ), - ) - if show_viewer: - target = scene.add_entity( - gs.morphs.Box( - pos=goal_pos, - quat=goal_quat, - size=(0.1, 0.1, 0.2), - ), - surface=gs.surfaces.Default( - color=(0.0, 0.9, 0.0, 0.5), - ), - ) - - scene.build() - - num_iter = 200 - lr = 1e-2 - - init_pos = gs.tensor([0.3, 0.1, 0.28], requires_grad=True) - init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True) - optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr) - - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3) - - for _ in range(num_iter): - scene.reset() - - box.set_pos(init_pos) - box.set_quat(init_quat) - - loss = 0 - for _ in range(horizon): - scene.step() - if show_viewer: - target.set_pos(goal_pos) - target.set_quat(goal_quat) - - box_state = box.get_state() - box_pos = box_state.pos - box_quat = box_state.quat - loss = torch.abs(box_pos - goal_pos).sum() + torch.abs(box_quat - goal_quat).sum() - - optimizer.zero_grad() - loss.backward() # this lets gradient flow all the way back to tensor input - optimizer.step() - scheduler.step() - - with torch.no_grad(): - init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True) - - assert_allclose(loss, 0.0, atol=1e-2) - - -@pytest.mark.required -@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) -def test_diff_sim_vs_solver_state_grad_parity(show_viewer): - scene = gs.Scene( - sim_options=gs.options.SimOptions( - dt=0.01, - gravity=(0.0, 0.0, 0.0), - requires_grad=True, - ), - rigid_options=gs.options.RigidOptions( - enable_collision=False, - ), - show_viewer=show_viewer, - ) - robot = scene.add_entity( - gs.morphs.Box( - size=(0.1, 0.1, 0.1), - pos=(0, 0, 0), - ) - ) - scene.build() - - ctrl = gs.tensor(np.random.randn(robot.n_dofs), dtype=gs.tc_float, requires_grad=True) - - grads = [] - for use_sim_state in (False, True): - scene.reset() - - robot.set_dofs_velocity(ctrl) - scene.step() - - if use_sim_state: - solver_state = scene.get_state().solvers_state[scene.solvers.index(scene.rigid_solver)] - chassis_pos = solver_state.links_pos[:, 0].squeeze() - else: - chassis_pos = robot.get_state().pos.squeeze() - - loss = torch.linalg.norm(chassis_pos) - loss.backward() - grad = ctrl.grad.detach().clone() - ctrl.grad.zero_() - - # Basic sanity check - assert (grad[..., :3].abs() > gs.EPS).all() - assert (grad[..., 3:].abs() < gs.EPS).all() - - grads.append(grad) - - assert_allclose(*grads, atol=gs.EPS) diff --git a/tests/test_grad_fd.py b/tests/test_grad_fd.py new file mode 100644 index 0000000000..536ea9cac2 --- /dev/null +++ b/tests/test_grad_fd.py @@ -0,0 +1,1932 @@ +"""FD-vs-analytical gradient checks for the differentiable rigid solver. + +This is the *correctness* layer of the grad test suite: every test compares the +diff-mode analytical gradient against central finite differences. It is split +into three sections, innermost backward layer first: + + 1. Forward-kinematics (constraints OFF): the unconstrained FK + velocity + gradient, per joint topology — the base local-gradient bar. + 2. Joint-limit (enable_joint_limit=True): the joint-limit inequality + constraint reverse, plus forward enforcement. + 3. Contact (enable_collision=True): the collision constraint + + diff-GJK narrow-phase reverse (box-box and plane-convex). + +Each section keeps its OWN scene builder — the FK builder disables all +constraints, while the joint-limit / contact builders turn the relevant +constraint on — so the configs never bleed across sections. +""" + +import math + +import numpy as np +import pytest +import torch + +import genesis as gs +from genesis.utils import set_random_seed +from genesis.utils.geom import R_to_quat +from genesis.utils.misc import qd_to_numpy, qd_to_torch, tensor_to_array + +from .utils import assert_allclose + + +pytestmark = [ + pytest.mark.debug(False), +] + + +# Per-precision FD tolerance. fp32 is intentionally looser. +# The "quat" kind covers outputs that go through a non-linear pose composition +# (set_dofs_velocity → state.quat) where Genesis autograd is currently a ~1% +# noisier than FD. +_TOL = { + ("64", "default"): dict(rtol=1e-4, atol=1e-6, eps=1e-5), + ("64", "quat"): dict(rtol=2e-2, atol=1e-3, eps=1e-5), + ("32", "default"): dict(rtol=2e-2, atol=2e-3, eps=1e-3), + ("32", "quat"): dict(rtol=5e-2, atol=5e-3, eps=1e-3), +} + + +_PRECISION_PARAMS = [ + pytest.param("64", marks=pytest.mark.precision("64"), id="fp64"), + pytest.param("32", marks=pytest.mark.precision("32"), id="fp32"), +] + +_N_ENVS_PARAMS = [ + pytest.param(0, id="single"), + pytest.param(4, id="batched"), +] + +_SUBSTEPS_PARAMS = [ + pytest.param(1, id="ss1"), + pytest.param(4, id="ss4"), +] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_scene(mjcf_path: str, *, requires_grad: bool, n_envs: int = 0, substeps: int = 1): + scene = gs.Scene( + sim_options=gs.options.SimOptions( + dt=0.01, + substeps=substeps, + gravity=(0.0, 0.0, -9.81), + requires_grad=requires_grad, + ), + rigid_options=gs.options.RigidOptions( + enable_collision=False, + enable_self_collision=False, + enable_joint_limit=False, + disable_constraint=True, + use_hibernation=False, + use_contact_island=False, + ), + show_viewer=False, + ) + robot = scene.add_entity(gs.morphs.MJCF(file=mjcf_path)) + scene.build(n_envs=n_envs) + return scene, robot + + +def _make_scene_pair(mjcf_file: str, n_envs: int = 0, substeps: int = 1): + """Build two parallel scenes from the same MJCF: + + * `scene_ana` runs the differentiable-mode forward and is the only one we + ever call `loss.backward()` on. Once a backward has run, that scene's + internal target-replay state is left in a configuration that silently + ignores subsequent setters — so reusing it for FD probes would give + loss_p == loss_m and a fake zero gradient. + * `scene_fd` runs the production forward (`requires_grad=False`) and is + what FD perturbs. By construction it never sees a backward, so each + reset → set → step cycle is clean. + + FD therefore checks "does the diff-mode analytical gradient match the + production forward's local sensitivity". With `n_envs > 0` both scenes + run in batched mode so we can verify that per-env adjoints are + independently correct. + """ + scene_ana, robot_ana = _build_scene(mjcf_file, requires_grad=True, n_envs=n_envs, substeps=substeps) + scene_fd, robot_fd = _build_scene(mjcf_file, requires_grad=False, n_envs=n_envs, substeps=substeps) + return scene_ana, robot_ana, scene_fd, robot_fd, mjcf_file + + +def _batch_size(scene) -> int: + """Effective batch dimension. scene.n_envs == 0 still allocates B=1 internally.""" + return scene.n_envs if scene.n_envs > 0 else 1 + + +def _input_shape(base_shape, n_envs): + """Setter inputs are unbatched when n_envs==0; batched (n_envs, *base) otherwise.""" + return (n_envs,) + tuple(base_shape) if n_envs > 0 else tuple(base_shape) + + +def _solver_state(scene): + """Return the rigid solver's RigidSolverState (grad-aware; provides + links_pos / links_quat for every link in the entity).""" + state = scene.get_state() + return state.solvers_state[scene.solvers.index(scene.rigid_solver)] + + +def _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input, # 1-D numpy array (fp64) + apply_fn, # callable(robot, x): apply x via a @tracked setter + loss_fn, # callable(scene, robot) -> scalar tensor + *, + label: str, + rtol: float = 1e-4, + atol: float = 1e-6, + eps: float = 1e-5, +): + # NOTE on tolerances: the production-mode and diff-mode forward kernels were + # verified to produce bit-identical state.pos/state.quat for the same input + # (probe_optionB.py, 2026-05-03), so an FD probed on the no-grad scene + # is a valid reference for the diff scene's analytical gradient. + # + # Most input/output pairs hit rtol=1e-4 trivially. The set_dofs_velocity → + # state.quat path is the outlier: it carries a known ~1% systematic drift + # between Genesis autograd and central FD (output magnitude is ~1e-2, so + # the absolute mismatch sits at ~1e-4 — well above truncation/roundoff + # at fp64). Tracked as a separate followup; for those cases callers should + # pass a looser `rtol` (e.g. 2e-2) rather than tightening this default. + base_np = np.asarray(init_input, dtype=np.float64).copy() + + # --- analytical (diff-mode scene) --- + x_ana = gs.tensor(base_np, dtype=gs.tc_float, requires_grad=True) + scene_ana.reset() + apply_fn(robot_ana, x_ana) + scene_ana.step() + loss = loss_fn(scene_ana, robot_ana) + assert loss.requires_grad, f"[{label}] loss does not require grad — output is not grad-aware" + loss.backward() + assert x_ana.grad is not None, f"[{label}] x.grad is None after backward" + ana_grad = x_ana.grad.detach().cpu().numpy().copy() + + # --- central FD (production-mode scene) --- + n = base_np.size + fd_grad = np.zeros_like(base_np) + for i in range(n): + plus = base_np.copy() + plus.reshape(-1)[i] = base_np.reshape(-1)[i] + eps + scene_fd.reset() + apply_fn(robot_fd, gs.tensor(plus, dtype=gs.tc_float)) + scene_fd.step() + loss_p = float(loss_fn(scene_fd, robot_fd).detach().cpu()) + + minus = base_np.copy() + minus.reshape(-1)[i] = base_np.reshape(-1)[i] - eps + scene_fd.reset() + apply_fn(robot_fd, gs.tensor(minus, dtype=gs.tc_float)) + scene_fd.step() + loss_m = float(loss_fn(scene_fd, robot_fd).detach().cpu()) + + fd_grad.reshape(-1)[i] = (loss_p - loss_m) / (2.0 * eps) + + assert_allclose( + torch.from_numpy(ana_grad), + torch.from_numpy(fd_grad), + rtol=rtol, + atol=atol, + err_msg=f"[{label}] FD vs analytical mismatch", + ) + + +def _grad_matches_fd_multistep( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_inputs, # list[np.ndarray] — one input per timestep, each shape matches the setter's expectation + apply_fn, # callable(robot, x): apply x via a @tracked setter + loss_fn, # callable(scene, robot) -> scalar tensor + *, + label: str, + rtol: float = 1e-4, + atol: float = 1e-6, + eps: float = 1e-5, +): + """Multi-step variant of `_grad_matches_fd`. + + Forwards `N = len(init_inputs)` simulator steps, applying a different + `@tracked`-setter input at each step. After `loss.backward()`, the + simulator must produce a correct adjoint for each step's input + independently (i.e. `scene._backward()` correctly walks the per-substep + `process_input_grad` chain). + + The FD reference perturbs each entry of each step's input separately + and re-runs the full N-step trajectory on `scene_fd`. Cost is + O(N · sum_inputs_size) forward runs of N steps each; with N=10 and + n_dofs ~ 3-7 this is ~600-1400 step calls per topology. + """ + N = len(init_inputs) + base_np = [np.asarray(inp, dtype=np.float64).copy() for inp in init_inputs] + + # --- analytical (diff-mode scene) --- + scene_ana.reset() + x_anas = [] + for t in range(N): + x = gs.tensor(base_np[t], dtype=gs.tc_float, requires_grad=True) + x_anas.append(x) + apply_fn(robot_ana, x) + scene_ana.step() + loss = loss_fn(scene_ana, robot_ana) + assert loss.requires_grad, f"[{label}] loss does not require grad — output is not grad-aware" + loss.backward() + ana_grads = [] + for t, x in enumerate(x_anas): + assert x.grad is not None, f"[{label}] step {t}: x.grad is None after backward" + ana_grads.append(x.grad.detach().cpu().numpy().copy()) + + # --- central FD (production-mode scene): for each (t, i) entry, run the + # full N-step trajectory twice with the perturbation injected only at + # step t. All other steps use the original input. + fd_grads = [np.zeros_like(b) for b in base_np] + + def _run_traj_with_perturb(t_perturb, i_perturb, sign): + scene_fd.reset() + for s in range(N): + inp = base_np[s].copy() + if s == t_perturb: + inp.reshape(-1)[i_perturb] += sign * eps + apply_fn(robot_fd, gs.tensor(inp, dtype=gs.tc_float)) + scene_fd.step() + return float(loss_fn(scene_fd, robot_fd).detach().cpu()) + + for t in range(N): + for i in range(base_np[t].size): + loss_p = _run_traj_with_perturb(t, i, +1) + loss_m = _run_traj_with_perturb(t, i, -1) + fd_grads[t].reshape(-1)[i] = (loss_p - loss_m) / (2.0 * eps) + + for t in range(N): + assert_allclose( + torch.from_numpy(ana_grads[t]), + torch.from_numpy(fd_grads[t]), + rtol=rtol, + atol=atol, + err_msg=f"[{label}] step {t}: FD vs analytical mismatch", + ) + + +# loss factories — all use sum-of-squared-deviation to a fixed random target so +# every entry of the input has a nontrivial sensitivity. Targets and outputs are +# both flattened before the subtraction so multi-link shapes (B, n_links, 3|4) +# don't trip torch broadcasting. +def _loss_state_pos(target): + flat = target.reshape(-1) + + def _fn(scene, robot): + return ((robot.get_state().pos.reshape(-1) - flat) ** 2).sum() + + return _fn + + +def _loss_state_quat(target): + flat = target.reshape(-1) + + def _fn(scene, robot): + return ((robot.get_state().quat.reshape(-1) - flat) ** 2).sum() + + return _fn + + +def _loss_links_pos(target): + flat = target.reshape(-1) + + def _fn(scene, robot): + return ((_solver_state(scene).links_pos.reshape(-1) - flat) ** 2).sum() + + return _fn + + +def _loss_links_quat(target): + flat = target.reshape(-1) + + def _fn(scene, robot): + return ((_solver_state(scene).links_quat.reshape(-1) - flat) ** 2).sum() + + return _fn + + +def _rand_np(shape, seed): + rng = np.random.default_rng(seed) + return rng.standard_normal(shape).astype(np.float64) + + +def _target(shape, seed): + return torch.from_numpy(_rand_np(shape, seed)).to(dtype=gs.tc_float, device=gs.device) + + +# --------------------------------------------------------------------------- +# Tests — one per joint topology, several (input, output) checks inside. +# --------------------------------------------------------------------------- + + +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("precision", _PRECISION_PARAMS) +@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS) +@pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS) +def test_diff_fk_freejoint(show_viewer, n_envs, precision, substeps): + """J1: single free body. Covers (n_envs ∈ {0, 4}) × (precision ∈ {fp64, fp32}).""" + scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair( + "xml/grad/free.xml", n_envs=n_envs, substeps=substeps + ) + n_dofs = robot_ana.n_dofs + B = _batch_size(scene_ana) + tol_default = _TOL[(precision, "default")] + tol_quat = _TOL[(precision, "quat")] + + tgt_pos = _target((B, 3), seed=1) + tgt_quat = _target((B, 4), seed=2) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((3,), n_envs), seed=10), + apply_fn=lambda r, x: r.set_pos(x), + loss_fn=_loss_state_pos(tgt_pos), + label="J1 set_pos → state.pos", + **tol_default, + ) + + init_q_shape = _input_shape((4,), n_envs) + init_q = np.broadcast_to(np.array([1.0, 0.0, 0.0, 0.0]), init_q_shape).copy() + init_q = init_q + 0.05 * _rand_np(init_q_shape, seed=11) + init_q = init_q / np.linalg.norm(init_q, axis=-1, keepdims=True) + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=init_q, + apply_fn=lambda r, x: r.set_quat(x), + loss_fn=_loss_state_quat(tgt_quat), + label="J1 set_quat → state.quat", + **tol_quat, + ) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=12), + apply_fn=lambda r, x: r.set_dofs_velocity(x), + loss_fn=_loss_state_pos(tgt_pos), + label="J1 set_dofs_velocity → state.pos (after 1 step)", + **tol_default, + ) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=13), + apply_fn=lambda r, x: r.set_dofs_velocity(x), + loss_fn=_loss_state_quat(tgt_quat), + label="J1 set_dofs_velocity → state.quat (after 1 step)", + **tol_quat, + ) + + # fp64 only: d(state.pos)/d(force) ≈ dt^2 / (2 * inertia) ≈ 1e-4 after 1 + # step. At fp32 with FD eps=1e-3 the loss difference is ~1e-7 — at fp32's + # precision floor — and the FD probe disagrees with analytical by ~1e-4 + # absolute, well above the fp32 default tol band. The J2/J3/J4/J5 force + # checks below are also fp64-only for the same reason; J2's + # `control_dofs_force → state.quat` does pass at fp32 only because its + # check uses the wider quat tolerance. + if precision == "64": + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=14), + apply_fn=lambda r, x: r.control_dofs_force(x), + loss_fn=_loss_state_pos(tgt_pos), + label="J1 control_dofs_force → state.pos (after 1 step)", + **tol_default, + ) + + +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("precision", _PRECISION_PARAMS) +@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS) +@pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS) +def test_diff_fk_revolute(show_viewer, n_envs, precision, substeps): + """J2: single revolute joint, fixed base. Covers (n_envs ∈ {0, 4}) × (precision ∈ {fp64, fp32}).""" + scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair( + "xml/grad/revolute.xml", n_envs=n_envs, substeps=substeps + ) + n_dofs = robot_ana.n_dofs # = 1 + B = _batch_size(scene_ana) + tol_default = _TOL[(precision, "default")] + tol_quat = _TOL[(precision, "quat")] + + tgt_pos = _target((B, 3), seed=21) + tgt_quat = _target((B, 4), seed=22) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=30), + apply_fn=lambda r, x: r.set_dofs_velocity(x), + loss_fn=_loss_state_pos(tgt_pos), + label="J2 set_dofs_velocity → state.pos", + **tol_default, + ) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=31), + apply_fn=lambda r, x: r.set_dofs_velocity(x), + loss_fn=_loss_state_quat(tgt_quat), + label="J2 set_dofs_velocity → state.quat", + **tol_quat, + ) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=32), + apply_fn=lambda r, x: r.control_dofs_force(x), + loss_fn=_loss_state_quat(tgt_quat), + label="J2 control_dofs_force → state.quat", + **tol_quat, + ) + + +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("precision", _PRECISION_PARAMS) +@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS) +@pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS) +def test_diff_fk_spherical(show_viewer, n_envs, precision, substeps): + """J6: single spherical (ball) joint, fixed base. 3 angular DOFs / 4 qpos + (quaternion). Exercises the SPHERICAL branch of + `kernel_manual_forward_kinematics_bw` — verifies that the `qloc → qpos[0:4]` + chain rule matches FD.""" + scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair( + "xml/grad/spherical.xml", n_envs=n_envs, substeps=substeps + ) + n_dofs = robot_ana.n_dofs # = 3 + B = _batch_size(scene_ana) + tol_default = _TOL[(precision, "default")] + tol_quat = _TOL[(precision, "quat")] + + tgt_pos = _target((B, 3), seed=61) + tgt_quat = _target((B, 4), seed=62) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=70), + apply_fn=lambda r, x: r.set_dofs_velocity(x), + loss_fn=_loss_state_pos(tgt_pos), + label="J6 set_dofs_velocity → state.pos", + **tol_default, + ) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=71), + apply_fn=lambda r, x: r.set_dofs_velocity(x), + loss_fn=_loss_state_quat(tgt_quat), + label="J6 set_dofs_velocity → state.quat", + **tol_quat, + ) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=72), + apply_fn=lambda r, x: r.control_dofs_force(x), + loss_fn=_loss_state_quat(tgt_quat), + label="J6 control_dofs_force → state.quat", + **tol_quat, + ) + + +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("precision", _PRECISION_PARAMS) +@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS) +@pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS) +def test_diff_fk_prismatic(show_viewer, n_envs, precision, substeps): + """J3: single prismatic joint, fixed base. No rotational DOF — skip the quat output.""" + scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair( + "xml/grad/prismatic.xml", n_envs=n_envs, substeps=substeps + ) + n_dofs = robot_ana.n_dofs # = 1 + B = _batch_size(scene_ana) + tol_default = _TOL[(precision, "default")] + tgt_pos = _target((B, 3), seed=41) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=50), + apply_fn=lambda r, x: r.set_dofs_velocity(x), + loss_fn=_loss_state_pos(tgt_pos), + label="J3 set_dofs_velocity → state.pos", + **tol_default, + ) + + # fp64-only — see J1's control_dofs_force comment for why FD-vs-analytical + # on force-driven position is at fp32's precision floor. + if precision == "64": + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=51), + apply_fn=lambda r, x: r.control_dofs_force(x), + loss_fn=_loss_state_pos(tgt_pos), + label="J3 control_dofs_force → state.pos", + **tol_default, + ) + + +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("precision", _PRECISION_PARAMS) +@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS) +@pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS) +def test_diff_fk_cartpole(show_viewer, n_envs, precision, substeps): + """J7: cartpole (prismatic cart + revolute pole). 2 links / 2 DOFs. + Same chain as J4 (multi-link entity with translation + rotation + coupling).""" + scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair("xml/cartpole.xml", n_envs=n_envs, substeps=substeps) + n_dofs = robot_ana.n_dofs # = 2 (slider + hinge) + B = _batch_size(scene_ana) + tol_default = _TOL[(precision, "default")] + tol_quat = _TOL[(precision, "quat")] + + tgt_links_pos = _target((B, 2, 3), seed=181) + tgt_links_quat = _target((B, 2, 4), seed=182) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=190), + apply_fn=lambda r, x: r.set_dofs_velocity(x), + loss_fn=_loss_links_pos(tgt_links_pos), + label="J7 set_dofs_velocity → links_pos", + **tol_default, + ) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=191), + apply_fn=lambda r, x: r.set_dofs_velocity(x), + loss_fn=_loss_links_quat(tgt_links_quat), + label="J7 set_dofs_velocity → links_quat", + **tol_quat, + ) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=192), + apply_fn=lambda r, x: r.control_dofs_force(x), + loss_fn=_loss_links_pos(tgt_links_pos), + label="J7 control_dofs_force → links_pos", + **tol_default, + ) + + +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("precision", _PRECISION_PARAMS) +@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS) +@pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS) +def test_diff_fk_hopper(show_viewer, n_envs, precision, substeps): + """J8: hopper, built collision-free. 4 links / 6 DOFs (planar floating base + rootx+rootz+rooty, then thigh/leg/foot hinges).""" + scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair("xml/hopper.xml", n_envs=n_envs, substeps=substeps) + n_dofs = robot_ana.n_dofs # = 6 (rootx, rootz, rooty, thigh, leg, foot) + n_links = robot_ana.n_links # = 5 (base + torso, thigh, leg, foot) + B = _batch_size(scene_ana) + tol_default = _TOL[(precision, "default")] + tol_quat = _TOL[(precision, "quat")] + # Hopper is the largest topology here (5 links, 6 DOFs). At fp32 the batched + # (n_envs=4) FD probe quantizes the small-sensitivity links_pos/links_quat + # entries to a ~2e-3 step, leaving a few entries ~6e-3 from the analytical. + # fp64 (single + batched) pins correctness; widen only the fp32 atol band so + # the FD-floor noise on the larger chain doesn't trip the check. + if precision == "32": + tol_default = dict(rtol=tol_default["rtol"], atol=8e-3, eps=tol_default["eps"]) + tol_quat = dict(rtol=tol_quat["rtol"], atol=8e-3, eps=tol_quat["eps"]) + + tgt_links_pos = _target((B, n_links, 3), seed=201) + tgt_links_quat = _target((B, n_links, 4), seed=202) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=210), + apply_fn=lambda r, x: r.set_dofs_velocity(x), + loss_fn=_loss_links_pos(tgt_links_pos), + label="J8 set_dofs_velocity → links_pos", + **tol_default, + ) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=211), + apply_fn=lambda r, x: r.set_dofs_velocity(x), + loss_fn=_loss_links_quat(tgt_links_quat), + label="J8 set_dofs_velocity → links_quat", + **tol_quat, + ) + + # fp64-only — see J1's control_dofs_force comment. + if precision == "64": + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=212), + apply_fn=lambda r, x: r.control_dofs_force(x), + loss_fn=_loss_links_pos(tgt_links_pos), + label="J8 control_dofs_force → links_pos", + **tol_default, + ) + + +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("precision", _PRECISION_PARAMS) +@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS) +@pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS) +def test_diff_fk_free_with_revolute(show_viewer, n_envs, precision, substeps): + """J4: freejoint root + one revolute child. Outputs use + multi-link solver_state.links_pos/quat so the child link's FK is exercised too.""" + scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair( + "xml/grad/free_with_revolute.xml", n_envs=n_envs, substeps=substeps + ) + n_dofs = robot_ana.n_dofs # 6 free + 1 hinge = 7 + n_links = robot_ana.n_links # 2 + B = _batch_size(scene_ana) + tol_default = _TOL[(precision, "default")] + tol_quat = _TOL[(precision, "quat")] + tgt_links_pos = _target((B, n_links, 3), seed=61) + tgt_links_quat = _target((B, n_links, 4), seed=62) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((3,), n_envs), seed=70), + apply_fn=lambda r, x: r.set_pos(x), + loss_fn=_loss_links_pos(tgt_links_pos), + label="J4 set_pos → links_pos", + **tol_default, + ) + + init_q_shape = _input_shape((4,), n_envs) + init_q = np.broadcast_to(np.array([1.0, 0.0, 0.0, 0.0]), init_q_shape).copy() + init_q = init_q + 0.05 * _rand_np(init_q_shape, seed=71) + init_q = init_q / np.linalg.norm(init_q, axis=-1, keepdims=True) + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=init_q, + apply_fn=lambda r, x: r.set_quat(x), + loss_fn=_loss_links_quat(tgt_links_quat), + label="J4 set_quat → links_quat", + **tol_quat, + ) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=72), + apply_fn=lambda r, x: r.set_dofs_velocity(x), + loss_fn=_loss_links_pos(tgt_links_pos), + label="J4 set_dofs_velocity → links_pos", + **tol_default, + ) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=73), + apply_fn=lambda r, x: r.set_dofs_velocity(x), + loss_fn=_loss_links_quat(tgt_links_quat), + label="J4 set_dofs_velocity → links_quat", + **tol_quat, + ) + + # fp64-only — see J1's control_dofs_force comment. + if precision == "64": + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=74), + apply_fn=lambda r, x: r.control_dofs_force(x), + loss_fn=_loss_links_pos(tgt_links_pos), + label="J4 control_dofs_force → links_pos", + **tol_default, + ) + + +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("precision", _PRECISION_PARAMS) +@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS) +@pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS) +def test_diff_fk_revolute_chain3(show_viewer, n_envs, precision, substeps): + """J5: 3-link serial revolute chain, fixed base. Tests deeper FK chain.""" + scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair( + "xml/grad/revolute_chain3.xml", n_envs=n_envs, substeps=substeps + ) + n_dofs = robot_ana.n_dofs # 3 + n_links = robot_ana.n_links # 3 + B = _batch_size(scene_ana) + tol_default = _TOL[(precision, "default")] + tol_quat = _TOL[(precision, "quat")] + tgt_links_pos = _target((B, n_links, 3), seed=81) + tgt_links_quat = _target((B, n_links, 4), seed=82) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=90), + apply_fn=lambda r, x: r.set_dofs_velocity(x), + loss_fn=_loss_links_pos(tgt_links_pos), + label="J5 set_dofs_velocity → links_pos", + **tol_default, + ) + + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=91), + apply_fn=lambda r, x: r.set_dofs_velocity(x), + loss_fn=_loss_links_quat(tgt_links_quat), + label="J5 set_dofs_velocity → links_quat", + **tol_quat, + ) + + # fp64-only — see J1's control_dofs_force comment. + if precision == "64": + _grad_matches_fd( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_input=_rand_np(_input_shape((n_dofs,), n_envs), seed=92), + apply_fn=lambda r, x: r.control_dofs_force(x), + loss_fn=_loss_links_pos(tgt_links_pos), + label="J5 control_dofs_force → links_pos", + **tol_default, + ) + + +# --------------------------------------------------------------------------- +# Multi-step gradient verification — exercises cross-step adjoint propagation. +# --------------------------------------------------------------------------- + + +_MULTISTEP_TOPOLOGIES = [ + pytest.param("xml/grad/free.xml", "J1 freejoint", 6, _loss_state_pos, (3,), 161, id="J1_free"), + pytest.param("xml/grad/revolute.xml", "J2 revolute", 1, _loss_state_pos, (3,), 162, id="J2_revolute"), + pytest.param("xml/grad/prismatic.xml", "J3 prismatic", 1, _loss_state_pos, (3,), 163, id="J3_prismatic"), + pytest.param( + "xml/grad/free_with_revolute.xml", "J4 free+revolute", 7, _loss_links_pos, (2, 3), 164, id="J4_free_rev" + ), + pytest.param("xml/grad/revolute_chain3.xml", "J5 chain3", 3, _loss_links_pos, (3, 3), 165, id="J5_chain3"), + pytest.param("xml/grad/spherical.xml", "J6 spherical", 3, _loss_state_pos, (3,), 166, id="J6_spherical"), + pytest.param("xml/cartpole.xml", "J7 cartpole", 2, _loss_links_pos, (2, 3), 167, id="J7_cartpole"), + pytest.param("xml/hopper.xml", "J8 hopper", 6, _loss_links_pos, (5, 3), 168, id="J8_hopper"), +] + + +@pytest.mark.required +@pytest.mark.precision("64") +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("mjcf_str, name, n_dofs, loss_factory, output_shape, seed", _MULTISTEP_TOPOLOGIES) +@pytest.mark.parametrize("substeps", _SUBSTEPS_PARAMS) +def test_diff_fk_multistep_control_force( + show_viewer, mjcf_str, name, n_dofs, loss_factory, output_shape, seed, substeps +): + """Per-topology check that `control_dofs_force` applied with a *different* + input at each of N=10 steps produces per-step gradients that match FD. + + fp64 + single env only: N=10 with batched + fp32 makes the test slow + (~30s per topology) and the fp32 + batched ulps-level noise across + multiple steps stacks up enough to require relaxed tolerances that + obscure real bugs. Single-step tests already cover fp32 + batched + against the same setter. + """ + scene_ana, robot_ana, scene_fd, robot_fd, _ = _make_scene_pair(mjcf_str, n_envs=0, substeps=substeps) + B = _batch_size(scene_ana) + target = _target((B, *output_shape), seed=seed) + + # 10 distinct force inputs, one per step. + N = 10 + init_inputs = [_rand_np((n_dofs,), seed=seed * 100 + t) for t in range(N)] + + _grad_matches_fd_multistep( + scene_ana, + robot_ana, + scene_fd, + robot_fd, + init_inputs=init_inputs, + apply_fn=lambda r, x: r.control_dofs_force(x), + loss_fn=loss_factory(target), + label=f"{name} control_dofs_force × {N} steps", + ) + + +# =========================================================================== +# Joint-limit constraint FD (enable_joint_limit=True -> constraints ON) +# =========================================================================== + + +def _build(mjcf_path: str, *, requires_grad: bool, enable_joint_limit: bool): + scene = gs.Scene( + sim_options=gs.options.SimOptions( + dt=1.0 / 60.0, + substeps=4, + gravity=(0.0, 0.0, 0.0), + requires_grad=requires_grad, + ), + rigid_options=gs.options.RigidOptions( + enable_collision=False, + enable_self_collision=False, + enable_joint_limit=enable_joint_limit, + disable_constraint=not enable_joint_limit, + use_hibernation=False, + use_contact_island=False, + ), + show_viewer=False, + ) + robot = scene.add_entity(gs.morphs.MJCF(file=mjcf_path)) + scene.build(n_envs=0) + return scene, robot + + +def _rigid_state(scene): + return scene.get_state().solvers_state[scene.solvers.index(scene.rigid_solver)] + + +@pytest.mark.required +@pytest.mark.precision("64") +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +def test_diff_joint_limit_forward_enforcement(show_viewer): + """With `enable_joint_limit=True` and slider range=[-4,4], pushing the + cart at v=100 for 60 steps must keep |x| bounded; with the limit off + the cart drifts past 90 m (control case).""" + mjcf_path = "xml/grad/slider_limit.xml" + + # Control: limit OFF + scene, robot = _build(mjcf_path, requires_grad=False, enable_joint_limit=False) + scene.reset() + robot.set_dofs_velocity(gs.tensor([100.0], dtype=gs.tc_float)) + for _ in range(60): + scene.step() + x_off = float(_rigid_state(scene).qpos[0, 0].detach()) + assert abs(x_off) > 50.0, f"control (limit OFF) cart should drift past 50m, got x={x_off}" + + # Limit ON — should stay bounded. + scene, robot = _build(mjcf_path, requires_grad=False, enable_joint_limit=True) + scene.reset() + robot.set_dofs_velocity(gs.tensor([100.0], dtype=gs.tc_float)) + for _ in range(60): + scene.step() + x_on = float(_rigid_state(scene).qpos[0, 0].detach()) + assert abs(x_on) <= 4.5, f"limit ON should keep |x| <= 4.5 (small margin for soft constraint), got x={x_on}" + + +@pytest.mark.required +@pytest.mark.precision("64") +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("init_vel", [0.5, 5.0]) +def test_diff_joint_limit_backward_finite_no_limit_hit(show_viewer, init_vel): + """When the rollout stays well inside the joint range, the joint-limit + branch should *not* activate (`pos_delta >= 0`), and the gradient should + match the no-limit baseline almost byte-exactly. This pins the + "limit-on but inactive" path through the constraint solver.""" + mjcf_path = "xml/grad/slider_limit.xml" + N_STEPS = 1 # short — cart doesn't reach limit + + grads = {} + for limit in (False, True): + scene, robot = _build(mjcf_path, requires_grad=True, enable_joint_limit=limit) + scene.reset() + v = gs.tensor([init_vel], dtype=gs.tc_float, requires_grad=True) + robot.set_dofs_velocity(v) + for _ in range(N_STEPS): + scene.step() + loss = (_rigid_state(scene).qpos[0, 0]) ** 2 + loss.backward() + assert v.grad is not None, f"limit={limit}: v.grad is None" + g = float(v.grad[0]) + assert math.isfinite(g), f"limit={limit}: gradient is not finite ({g})" + grads[limit] = g + + # Limit-inactive case should match the no-limit baseline tightly — the + # constraint branch only runs `n_constraints += 0`, so the autograd tape + # should be identical up to floating-point. + assert_allclose(grads[True], grads[False], rtol=1e-6, atol=1e-9) + + +@pytest.mark.required +@pytest.mark.precision("64") +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +def test_diff_joint_limit_backward_fd_one_step(show_viewer): + """FD vs analytical gradient when the cart starts well inside the limit + and takes a single step — verifies the constraint-solver-inclusive + forward+backward chain still satisfies central FD.""" + mjcf_path = "xml/grad/slider_limit.xml" + init_vel = 2.0 + eps = 1e-5 + + # Analytical + scene_ana, robot_ana = _build(mjcf_path, requires_grad=True, enable_joint_limit=True) + scene_ana.reset() + v = gs.tensor([init_vel], dtype=gs.tc_float, requires_grad=True) + robot_ana.set_dofs_velocity(v) + scene_ana.step() + loss = (_rigid_state(scene_ana).qpos[0, 0]) ** 2 + loss.backward() + ana = float(v.grad[0]) + + # FD + scene_fd, robot_fd = _build(mjcf_path, requires_grad=False, enable_joint_limit=True) + + def loss_at(val: float) -> float: + scene_fd.reset() + robot_fd.set_dofs_velocity(gs.tensor([val], dtype=gs.tc_float)) + scene_fd.step() + return float((_rigid_state(scene_fd).qpos[0, 0]) ** 2) + + fd = (loss_at(init_vel + eps) - loss_at(init_vel - eps)) / (2 * eps) + + assert_allclose(ana, fd, rtol=1e-3, atol=1e-6) + + +# (init_vel, n_steps) cases where the cart actually crosses |x|=4 during the +# rollout. Each case engages the constraint solver during the integration — +# they cover the `M^{-1} J^T λ` correction path that the unconstrained +# `kernel_manual_compute_qacc_bw` could not produce. Resolved 2026-05-25 by +# wiring `constraint_solver.backward` + `kernel_manual_add_joint_limit_constraints_bw` +# into `substep_pre_coupling_grad`. +_FD_ACTIVE_CASES = [ + (500.0, 1), + (200.0, 2), + (100.0, 5), + (50.0, 10), +] + + +@pytest.mark.required +@pytest.mark.precision("64") +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("init_vel,n_steps", _FD_ACTIVE_CASES) +def test_diff_joint_limit_backward_fd_active(show_viewer, init_vel, n_steps): + """FD vs analytical when the cart actually crosses the joint limit during + the rollout. Exercises the constrained backward path + (`constraint_solver.backward` + `kernel_manual_add_joint_limit_constraints_bw`) + on cases where the unconstrained IFT alone would drop the `M^{-1} J^T λ` + correction and disagree with FD (often sign-flipped). Snapshot of the + expected gradients (FP64 CPU, FD with eps=1e-4) at fix time: + + init_vel n_steps x_final v.grad + 500 1 +4.464 +5.51e-2 + 200 2 +4.203 +9.46e-2 + 100 5 +3.892 -1.60e-1 + 50 10 +3.606 -1.16e+0 + """ + mjcf_path = "xml/grad/slider_limit.xml" + eps = 1e-4 + + # Analytical + scene_ana, robot_ana = _build(mjcf_path, requires_grad=True, enable_joint_limit=True) + scene_ana.reset() + v = gs.tensor([init_vel], dtype=gs.tc_float, requires_grad=True) + robot_ana.set_dofs_velocity(v) + for _ in range(n_steps): + scene_ana.step() + x_final = float(_rigid_state(scene_ana).qpos[0, 0].detach()) + # Setup sanity: the cart must have entered the limit band, otherwise this + # case wouldn't actually exercise the constraint correction path. + assert abs(x_final) > 3.5, ( + f"setup error: init_vel={init_vel}, n_steps={n_steps} did not bring " + f"the cart near the limit (x_final={x_final}); pick a larger v0 or " + f"more steps." + ) + loss = (_rigid_state(scene_ana).qpos[0, 0]) ** 2 + loss.backward() + ana = float(v.grad[0]) + + # FD + scene_fd, robot_fd = _build(mjcf_path, requires_grad=False, enable_joint_limit=True) + + def loss_at(val: float) -> float: + scene_fd.reset() + robot_fd.set_dofs_velocity(gs.tensor([val], dtype=gs.tc_float)) + for _ in range(n_steps): + scene_fd.step() + return float((_rigid_state(scene_fd).qpos[0, 0]) ** 2) + + fd = (loss_at(init_vel + eps) - loss_at(init_vel - eps)) / (2 * eps) + + assert_allclose(ana, fd, rtol=1e-3, atol=1e-6) + + +# Per-step force horizons that drive the cart into the slider limit through +# `control_dofs_force`. Constant +500 N over `n_steps` accelerates the +# unit-mass cart past |x|=4 within ~10 substep-groups at dt=1/60, substeps=4 +# (default solref); shorter horizons leave the cart inside the band and +# don't exercise the constraint backward, so we restrict to multi-step +# active cases. n_steps=10 probes whether the per-step `force.grad` for +# early-horizon steps leaks a wrong gradient when the constrained backward +# chain (`constraint_solver.backward` + manual joint-limit BW + +# `fwd_dynamics_without_qacc.grad` accumulation) runs across many substeps. +_FD_FORCE_CASES = [10] + + +@pytest.mark.required +@pytest.mark.precision("64") +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("n_steps", _FD_FORCE_CASES) +def test_diff_joint_limit_backward_fd_per_step_force(show_viewer, n_steps): + """Central-FD vs analytical gradient on a per-step `control_dofs_force` + time series that drives the cart into the slider limit during the + rollout. Multi-step variant of `test_diff_joint_limit_backward_fd_active`. + """ + mjcf_path = "xml/grad/slider_limit.xml" + eps = 1e-2 + force_value = 500.0 + init_force = np.full((n_steps, 1), force_value, dtype=np.float64) + + # Analytical + scene_ana, robot_ana = _build(mjcf_path, requires_grad=True, enable_joint_limit=True) + scene_ana.reset() + forces = [gs.tensor(init_force[t], dtype=gs.tc_float, requires_grad=True) for t in range(n_steps)] + for t in range(n_steps): + robot_ana.control_dofs_force(forces[t]) + scene_ana.step() + x_final = float(_rigid_state(scene_ana).qpos[0, 0].detach()) + # Setup sanity: the cart must have entered the limit band, otherwise this + # case wouldn't actually exercise the multi-step constraint backward. + assert abs(x_final) > 3.5, ( + f"setup error: n_steps={n_steps} at force={force_value} did not bring " + f"the cart near the limit (x_final={x_final}); pick a larger force or " + f"more steps." + ) + loss = (_rigid_state(scene_ana).qpos[0, 0]) ** 2 + loss.backward() + for t, f in enumerate(forces): + assert f.grad is not None, f"step {t}: force.grad is None" + ana = np.array([float(f.grad[0]) for f in forces]) + + # FD per-step + scene_fd, robot_fd = _build(mjcf_path, requires_grad=False, enable_joint_limit=True) + + def loss_at(perturbed: np.ndarray) -> float: + scene_fd.reset() + for t in range(n_steps): + robot_fd.control_dofs_force(gs.tensor(perturbed[t], dtype=gs.tc_float)) + scene_fd.step() + return float((_rigid_state(scene_fd).qpos[0, 0]) ** 2) + + fd = np.zeros(n_steps) + for t in range(n_steps): + plus = init_force.copy() + plus[t, 0] += eps + minus = init_force.copy() + minus[t, 0] -= eps + fd[t] = (loss_at(plus) - loss_at(minus)) / (2 * eps) + + # Per-step comparison so the failure message identifies the offending step. + for t in range(n_steps): + assert_allclose( + ana[t], + fd[t], + rtol=1e-3, + atol=1e-4, + err_msg=( + f"per-step force.grad mismatch at t={t}/{n_steps} " + f"(ana={ana[t]:+.4e}, fd={fd[t]:+.4e}); full ana={ana}, fd={fd}" + ), + ) + + +def _build_cartpole(mjcf_path: str, *, requires_grad: bool): + """Build a multi-body cart+pole scene with gravity + slider limit on.""" + scene = gs.Scene( + sim_options=gs.options.SimOptions( + dt=1.0 / 60.0, + substeps=4, + gravity=(0.0, 0.0, -9.81), + requires_grad=requires_grad, + ), + rigid_options=gs.options.RigidOptions( + enable_collision=False, + enable_self_collision=False, + enable_joint_limit=True, + disable_constraint=False, + use_hibernation=False, + use_contact_island=False, + ), + show_viewer=False, + ) + robot = scene.add_entity(gs.morphs.MJCF(file=mjcf_path)) + scene.build(n_envs=0) + return scene, robot + + +@pytest.mark.required +@pytest.mark.precision("64") +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("n_steps", [15]) +def test_diff_joint_limit_backward_fd_per_step_force_cartpole(show_viewer, n_steps): + """Multi-body variant of `test_diff_joint_limit_backward_fd_per_step_force`. + + Same per-step `control_dofs_force` setup but on the *cart+pole* MJCF + (slider [-4, 4] + free-rotating hinge with gravity). cart DOF is + actuated at `force_value`, pole DOF takes 0 force. Loss is `cart_x^2` + at the terminal step. Both cart_force.grad AND pole_force.grad are + compared to central FD — pole_force.grad is *non-zero* because a + hinge torque accelerates the pole, the pole's swing exerts reactive + horizontal force on the cart via the hinge, which moves cart_x. + """ + mjcf_path = "xml/cartpole.xml" + eps = 1e-2 + # cart+pole effective mass ≈ 11 (cart 1 + pole 10 horizontal-locked at + # hanging), so cart force needs to be larger than the cart-only test + # to reach the limit within `n_steps`. + force_value = 2000.0 + # Force shape per step: (n_dofs,) = (cart_f, pole_f). pole stays at 0. + init_force = np.zeros((n_steps, 2), dtype=np.float64) + init_force[:, 0] = force_value + + # Initial state: cart at x=0, pole hanging down at theta=-pi (same as + # `CartPoleSwingUpEnv._init_qpos`). Deterministic; same in ana / FD. + init_qpos = [0.0, -math.pi] + + # Analytical + scene_ana, robot_ana = _build_cartpole(mjcf_path, requires_grad=True) + scene_ana.reset() + robot_ana.set_dofs_position(gs.tensor(init_qpos, dtype=gs.tc_float)) + forces = [gs.tensor(init_force[t], dtype=gs.tc_float, requires_grad=True) for t in range(n_steps)] + for t in range(n_steps): + robot_ana.control_dofs_force(forces[t]) + scene_ana.step() + x_final = float(_rigid_state(scene_ana).qpos[0, 0].detach()) + assert abs(x_final) > 3.5, ( + f"setup error: cart+pole at n_steps={n_steps}, force={force_value} " + f"did not bring the cart near the limit (x_final={x_final}); pick a " + f"larger force or more steps." + ) + loss = (_rigid_state(scene_ana).qpos[0, 0]) ** 2 + loss.backward() + for t, f in enumerate(forces): + assert f.grad is not None, f"step {t}: force.grad is None" + # cart-force grad per step (slot 0); slot 1 is pole-force grad, must be 0. + ana_cart = np.array([float(f.grad[0]) for f in forces]) + ana_pole = np.array([float(f.grad[1]) for f in forces]) + + # FD per-step on the cart-force slot only. + scene_fd, robot_fd = _build_cartpole(mjcf_path, requires_grad=False) + + def loss_at(perturbed: np.ndarray) -> float: + scene_fd.reset() + robot_fd.set_dofs_position(gs.tensor(init_qpos, dtype=gs.tc_float)) + for t in range(n_steps): + robot_fd.control_dofs_force(gs.tensor(perturbed[t], dtype=gs.tc_float)) + scene_fd.step() + return float((_rigid_state(scene_fd).qpos[0, 0]) ** 2) + + fd_cart = np.zeros(n_steps) + fd_pole = np.zeros(n_steps) + for t in range(n_steps): + plus = init_force.copy() + plus[t, 0] += eps + minus = init_force.copy() + minus[t, 0] -= eps + fd_cart[t] = (loss_at(plus) - loss_at(minus)) / (2 * eps) + + plus = init_force.copy() + plus[t, 1] += eps + minus = init_force.copy() + minus[t, 1] -= eps + fd_pole[t] = (loss_at(plus) - loss_at(minus)) / (2 * eps) + + # Cart-force grad — straight chain from action to cart_x. + for t in range(n_steps): + assert_allclose( + ana_cart[t], + fd_cart[t], + rtol=1e-3, + atol=1e-4, + err_msg=( + f"cart-pole cart_force.grad mismatch at t={t}/{n_steps} " + f"(ana={ana_cart[t]:+.4e}, fd={fd_cart[t]:+.4e}); " + f"full ana={ana_cart}, fd={fd_cart}" + ), + ) + # Pole-force grad — hinge torque chain: pole_force -> pole_angle -> + # pole COM horizontal accel -> reactive force on cart via hinge -> + # cart_x. Non-zero, must still match FD step-by-step. + for t in range(n_steps): + assert_allclose( + ana_pole[t], + fd_pole[t], + rtol=1e-3, + atol=1e-4, + err_msg=( + f"cart-pole pole_force.grad mismatch at t={t}/{n_steps} " + f"(ana={ana_pole[t]:+.4e}, fd={fd_pole[t]:+.4e}); " + f"full ana_pole={ana_pole}, fd_pole={fd_pole}" + ), + ) + + +def _build_hopper(mjcf_path: str, *, requires_grad: bool): + """Build the hopper collision-free with joint limits ON and gravity off. + + Collision is off so the joint-limit constraint is the only constraint in + play (no foot-ground contact); gravity is off so the base doesn't drift, + keeping the rollout focused on driving a leg joint into its range limit. + """ + scene = gs.Scene( + sim_options=gs.options.SimOptions( + dt=1.0 / 60.0, + substeps=4, + gravity=(0.0, 0.0, 0.0), + requires_grad=requires_grad, + ), + rigid_options=gs.options.RigidOptions( + enable_collision=False, + enable_self_collision=False, + enable_joint_limit=True, + disable_constraint=False, + use_hibernation=False, + use_contact_island=False, + ), + show_viewer=False, + ) + robot = scene.add_entity(gs.morphs.MJCF(file=mjcf_path)) + scene.build(n_envs=0) + return scene, robot + + +@pytest.mark.required +@pytest.mark.precision("64") +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("n_steps", [10]) +def test_diff_joint_limit_backward_fd_per_step_force_hopper(show_viewer, n_steps): + """Joint-limit backward FD check on the hopper — combines the multi-joint + planar base (slide+slide+hinge on the torso) with an *active* joint-limit + constraint, both collision-free. + + A constant torque on the foot joint drives it past its `[-0.785, 0.785]` + range during the rollout, engaging the joint-limit inequality constraint. + The loss is on the foot LINK world position (so the gradient flows through + the full FK chain — exercising `kernel_manual_forward_kinematics_bw`'s multi-joint + reverse for the base — as well as the constraint backward). Every DOF's + per-step `control_dofs_force.grad` is compared to central FD: the foot DOF + is the forced + limited one; the other DOFs are reached only through the + articulated coupling, so their FD sensitivity is non-trivial too. + """ + mjcf_path = "xml/hopper.xml" + n_dofs = 6 # rootx, rootz, rooty, thigh, leg, foot + foot_dof = 5 + eps = 1e-2 + force_value = 200.0 + init_force = np.zeros((n_steps, n_dofs), dtype=np.float64) + init_force[:, foot_dof] = force_value + + def _links_pos_sq_loss(scene): + lp = _rigid_state(scene).links_pos + return (lp.reshape(-1) ** 2).sum() + + # Analytical + scene_ana, robot_ana = _build_hopper(mjcf_path, requires_grad=True) + scene_ana.reset() + forces = [gs.tensor(init_force[t], dtype=gs.tc_float, requires_grad=True) for t in range(n_steps)] + for t in range(n_steps): + robot_ana.control_dofs_force(forces[t]) + scene_ana.step() + foot_q = float(_rigid_state(scene_ana).qpos[0, foot_dof].detach()) + # Setup sanity: the foot must have entered its limit band, else the + # constraint backward isn't exercised. + assert abs(foot_q) > 0.7, ( + f"setup error: n_steps={n_steps} at foot force={force_value} did not " + f"drive the foot joint near its 0.785 limit (foot_q={foot_q}); pick a " + f"larger force or more steps." + ) + loss = _links_pos_sq_loss(scene_ana) + loss.backward() + for t, f in enumerate(forces): + assert f.grad is not None, f"step {t}: force.grad is None" + ana = np.array([[float(f.grad[d]) for d in range(n_dofs)] for f in forces]) # (n_steps, n_dofs) + + # FD per-step, per-dof + scene_fd, robot_fd = _build_hopper(mjcf_path, requires_grad=False) + + def loss_at(perturbed: np.ndarray) -> float: + scene_fd.reset() + for t in range(n_steps): + robot_fd.control_dofs_force(gs.tensor(perturbed[t], dtype=gs.tc_float)) + scene_fd.step() + return float(_links_pos_sq_loss(scene_fd).detach()) + + fd = np.zeros((n_steps, n_dofs)) + for t in range(n_steps): + for d in range(n_dofs): + plus = init_force.copy() + plus[t, d] += eps + minus = init_force.copy() + minus[t, d] -= eps + fd[t, d] = (loss_at(plus) - loss_at(minus)) / (2 * eps) + + for t in range(n_steps): + for d in range(n_dofs): + assert_allclose( + ana[t, d], + fd[t, d], + rtol=1e-3, + atol=1e-4, + err_msg=( + f"hopper force.grad mismatch at t={t}/{n_steps}, dof={d} " + f"(ana={ana[t, d]:+.4e}, fd={fd[t, d]:+.4e})\nfull ana=\n{ana}\nfull fd=\n{fd}" + ), + ) + + +# =========================================================================== +# Collision / diff-GJK contact FD (enable_collision=True -> constraints ON) +# =========================================================================== + + +def _build_box_box(*, requires_grad: bool): + scene = gs.Scene( + sim_options=gs.options.SimOptions( + dt=0.01, + substeps=2, + gravity=(0.0, 0.0, -9.81), + requires_grad=requires_grad, + ), + rigid_options=gs.options.RigidOptions( + enable_collision=True, + enable_self_collision=False, + enable_joint_limit=False, + disable_constraint=False, + use_hibernation=False, + use_contact_island=False, + box_box_detection=False, # general convex-convex GJK (differentiable) path + ), + show_viewer=False, + ) + scene.add_entity(gs.morphs.Box(size=(2.0, 2.0, 0.2), pos=(0.0, 0.0, 0.1), fixed=True)) + box = scene.add_entity(gs.morphs.Box(size=(0.4, 0.4, 0.4), pos=(0.0, 0.0, 0.4))) + scene.build(n_envs=0) + return scene, box + + +def _build_plane_convex(shape: str, *, requires_grad: bool): + scene = gs.Scene( + sim_options=gs.options.SimOptions( + dt=0.01, + substeps=2, + gravity=(0.0, 0.0, -9.81), + requires_grad=requires_grad, + ), + rigid_options=gs.options.RigidOptions( + enable_collision=True, + enable_self_collision=False, + enable_joint_limit=False, + disable_constraint=False, + use_hibernation=False, + use_contact_island=False, + box_box_detection=False, + ), + show_viewer=False, + ) + scene.add_entity(gs.morphs.Plane()) + if shape == "box": + obj = scene.add_entity(gs.morphs.Box(size=(0.4, 0.4, 0.4), pos=(0.0, 0.0, 0.3))) + elif shape == "sphere": + obj = scene.add_entity(gs.morphs.Sphere(radius=0.2, pos=(0.0, 0.0, 0.3))) + elif shape == "capsule": + obj = scene.add_entity(gs.morphs.MJCF(file="xml/grad/capsule.xml", align=False)) + else: + raise ValueError(shape) + scene.build(n_envs=0) + return scene, obj + + +def _n_contacts(scene) -> int: + return int(scene.rigid_solver.collider._collider_state.n_contacts.to_numpy()[0]) + + +def _settle(scene, obj, n_settle: int): + zero = gs.tensor([0.0] * 6, dtype=gs.tc_float) + for _ in range(n_settle): + obj.control_dofs_force(zero) + scene.step() + + +def _run_fd_per_step_force(build_fn, rest_dofs, *, base_force, n_settle, n_steps, fd_dofs, eps, rtol, atol): + """Analytical-vs-central-FD driver for a free convex pressed into a fixed + collider by a per-step downward force. + + The force is purely downward/centered: a lateral or torque component would + tip the body and change the contact manifold (breaking contact-pair + preservation), so only the load-bearing z DOF is FD-checked. Its gradient + still runs through contact_pos / normal / penetration inside the constraint + reverse and the differentiable narrow phase. The FD scene runs in + `requires_grad=True` (forward only) so it produces the *same* contact + manifold as the analytical scene; a production-mode scene would take a + different (non-diff) narrow-phase path with a different contact set. + """ + init_force = np.broadcast_to(base_force, (n_steps, 6)).copy() + + # --- analytical --- + scene_ana, obj_ana = build_fn(requires_grad=True) + scene_ana.reset() + obj_ana.set_dofs_position(gs.tensor(rest_dofs, dtype=gs.tc_float).sceneless()) + _settle(scene_ana, obj_ana, n_settle) + nc = _n_contacts(scene_ana) + assert nc > 0, f"setup error: not in contact after settle (n_contacts={nc})" + + forces = [gs.tensor(init_force[t], dtype=gs.tc_float, requires_grad=True) for t in range(n_steps)] + for t in range(n_steps): + obj_ana.control_dofs_force(forces[t]) + scene_ana.step() + assert _n_contacts(scene_ana) == nc, "contact set changed during grad window — FD invalid" + loss = (_rigid_state(scene_ana).qpos[0, :3] ** 2).sum() + scene_ana.backward(loss) + ana = np.array([[float(f.grad[d]) for d in range(6)] for f in forces]) # (N, 6) + + # --- central FD, contact set preserved --- + scene_fd, obj_fd = build_fn(requires_grad=True) + + def loss_at(perturbed: np.ndarray) -> float: + scene_fd.reset() + obj_fd.set_dofs_position(gs.tensor(rest_dofs, dtype=gs.tc_float).sceneless()) + _settle(scene_fd, obj_fd, n_settle) + for t in range(n_steps): + obj_fd.control_dofs_force(gs.tensor(perturbed[t], dtype=gs.tc_float)) + scene_fd.step() + assert _n_contacts(scene_fd) == nc, "contact set changed under FD perturbation" + return float((_rigid_state(scene_fd).qpos[0, :3] ** 2).sum().detach()) + + fd = np.full((n_steps, 6), np.nan) + for t in range(n_steps): + for d in fd_dofs: + plus = init_force.copy() + plus[t, d] += eps + minus = init_force.copy() + minus[t, d] -= eps + fd[t, d] = (loss_at(plus) - loss_at(minus)) / (2 * eps) + + # Contact gradients are small (stiff contact barely moves), so the band is + # absolute-dominated; rtol pins the load-bearing z entry. + for t in range(n_steps): + for d in fd_dofs: + assert_allclose( + ana[t, d], + fd[t, d], + rtol=rtol, + atol=atol, + err_msg=f"contact force.grad mismatch at t={t}/{n_steps}, dof={d}\nana=\n{ana}\nfd=\n{fd}", + ) + + +@pytest.mark.required +@pytest.mark.precision("64") +@pytest.mark.parametrize( + "backend", + [ + gs.cpu, + pytest.param( + gs.gpu, + marks=pytest.mark.skip( + reason="General convex-convex (box-box) differentiable contact is not supported on GPU: the " + "GPU split (multicontact) narrow phase drops contacts under requires_grad=True (n_contacts=0). " + "Only plane-convex is GPU-differentiable (see test_diff_contact_fd_plane_convex). " + "Revisit when the split path's diff-contact handling is fixed." + ), + ), + ], +) +def test_diff_contact_fd_per_step_force(show_viewer): + # Box rests on the ground top (z=0.2) at center z=0.40; settle to a stable + # multi-contact manifold, then a short grad window with a per-step push. + _run_fd_per_step_force( + _build_box_box, + [0.0, 0.0, 0.40, 0.0, 0.0, 0.0], # freejoint 6 DOFs: xyz + rotvec(=identity) + base_force=np.array([0.0, 0.0, -8.0, 0.0, 0.0, 0.0], dtype=np.float64), + n_settle=12, + n_steps=2, + fd_dofs=(2,), + eps=1e-2, + rtol=2e-3, + atol=1e-10, + ) + + +# rest z so the body's lowest point sits on the plane (z=0): box/sphere half +# extent 0.2; capsule radius 0.1 + half_length 0.2 = 0.3 (upright). +_PLANE_REST = { + "box": [0.0, 0.0, 0.20, 0.0, 0.0, 0.0], + "sphere": [0.0, 0.0, 0.20, 0.0, 0.0, 0.0], + "capsule": [0.0, 0.0, 0.30, 0.0, 0.0, 0.0], +} + + +@pytest.mark.required +@pytest.mark.precision("64") +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("shape", ["box", "sphere", "capsule"]) +def test_diff_contact_fd_plane_convex(shape, show_viewer): + # Plane (fixed) + free convex. The analytic plane contact is reconstructed + # differentiably via `func_differentiable_plane_contact` (stored convex + # support core + radius), so the same FD chain as box-box applies. + _run_fd_per_step_force( + lambda *, requires_grad: _build_plane_convex(shape, requires_grad=requires_grad), + _PLANE_REST[shape], + base_force=np.array([0.0, 0.0, -8.0, 0.0, 0.0, 0.0], dtype=np.float64), + n_settle=12, + n_steps=2, + fd_dofs=(2,), + eps=1e-2, + rtol=2e-3, + atol=1e-10, + ) + + +# =========================================================================== +# Low-level contact-detection + constraint-solver backward FD +# =========================================================================== + + +@pytest.mark.required +@pytest.mark.precision("64") +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +def test_diff_contact(): + RTOL = 1e-4 + + scene = gs.Scene( + sim_options=gs.options.SimOptions( + dt=0.01, + # Turn on differentiable mode + requires_grad=True, + ), + show_viewer=False, + ) + + box_size = 0.25 + box_spacing = box_size + vec_one = np.array([1.0, 1.0, 1.0]) + box_pos_offset = (0.0, 0.0, 0.0) + 0.5 * box_size * vec_one + + box0 = scene.add_entity( + gs.morphs.Box(size=box_size * vec_one, pos=box_pos_offset), + ) + box1 = scene.add_entity( + gs.morphs.Box(size=box_size * vec_one, pos=box_pos_offset + 0.8 * box_spacing * np.array([0, 0, 1])), + ) + scene.build() + solver = scene.sim.rigid_solver + collider = solver.collider + + # Set up initial configuration + x_ang, y_ang, z_ang = 3.0, 3.0, 3.0 + box1.set_quat(R_to_quat(gs.euler_to_R([np.deg2rad(x_ang), np.deg2rad(y_ang), np.deg2rad(z_ang)]))) + + box0_init_pos = box0.get_pos().clone() + box1_init_pos = box1.get_pos().clone() + box0_init_quat = box0.get_quat().clone() + box1_init_quat = box1.get_quat().clone() + + ### Compute the initial loss and compute gradients using differentiable contact detection + # Detect contact + collider.detection() + + # Get contact outputs and their grads + contacts = collider.get_contacts(as_tensor=True, to_torch=True, keep_batch_dim=True) + normal = contacts["normal"].requires_grad_() + position = contacts["position"].requires_grad_() + penetration = contacts["penetration"].requires_grad_() + + loss = ((normal * position).sum(dim=-1) * penetration).sum() + dL_dnormal = torch.autograd.grad(loss, normal, retain_graph=True)[0] + dL_dposition = torch.autograd.grad(loss, position, retain_graph=True)[0] + dL_dpenetration = torch.autograd.grad(loss, penetration)[0] + + # Compute analytical gradients of the geoms position and quaternion + collider.backward(dL_dposition, dL_dnormal, dL_dpenetration) + dL_dpos = qd_to_torch(solver.geoms_state.pos.grad) + dL_dquat = qd_to_torch(solver.geoms_state.quat.grad) + + ### Compute directional derivatives along random directions + FD_EPS = 1e-5 + TRIALS = 100 + + def compute_dL_error(dL_dx, x_type): + dL_error_rel = 0.0 + + box0_input_pos = box0_init_pos + box1_input_pos = box1_init_pos + box0_input_quat = box0_init_quat + box1_input_quat = box1_init_quat + + for _ in range(TRIALS): + rand_dx = torch.randn_like(dL_dx) + rand_dx = torch.nn.functional.normalize(rand_dx, dim=-1) + + dL = (rand_dx * dL_dx).sum() + + lossPs = [] + for sign in (1, -1): + # Compute query point + if x_type == "pos": + box0_input_pos = box0_init_pos + sign * rand_dx[0, 0] * FD_EPS + box1_input_pos = box1_init_pos + sign * rand_dx[1, 0] * FD_EPS + else: + # FIXME: The quaternion should be normalized + box0_input_quat = box0_init_quat + sign * rand_dx[0, 0] * FD_EPS + box1_input_quat = box1_init_quat + sign * rand_dx[1, 0] * FD_EPS + + # Update box positions + box0.set_pos(box0_input_pos) + box1.set_pos(box1_input_pos) + box0.set_quat(box0_input_quat) + box1.set_quat(box1_input_quat) + + # Re-detect contact. + # We need to manually reset the contact counter as we are not running the whole sim step. + collider._collider_state.n_contacts.fill(0) + collider.detection() + contacts = collider.get_contacts(as_tensor=True, to_torch=True, keep_batch_dim=True) + normal, position, penetration = contacts["normal"], contacts["position"], contacts["penetration"] + + # Compute loss + loss = ((normal * position).sum(dim=-1) * penetration).sum() + lossPs.append(loss) + + dL_fd = (lossPs[0] - lossPs[1]) / (2 * FD_EPS) + dL_error_rel += (dL - dL_fd).abs() / max(dL.abs(), dL_fd.abs(), gs.EPS) + + dL_error_rel /= TRIALS + return dL_error_rel + + dL_dpos_error_rel = compute_dL_error(dL_dpos, "pos") + assert_allclose(dL_dpos_error_rel, 0.0, atol=RTOL) + dL_dquat_error_rel = compute_dL_error(dL_dquat, "quat") + assert_allclose(dL_dquat_error_rel, 0.0, atol=RTOL) + + +# We need to use 64-bit precision for this test because we need to use sufficiently small perturbation to get reliable +# gradient estimates through finite difference method. This small perturbation is not supported by 32-bit precision in +# stable way. +@pytest.mark.required +@pytest.mark.precision("64") +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +def test_diff_solver(monkeypatch): + from genesis.engine.solvers.rigid.constraint.solver import func_solve_init, func_solve_body + from genesis.engine.solvers.rigid.rigid_solver import kernel_step_1 + + RTOL = 1e-4 + + scene = gs.Scene( + sim_options=gs.options.SimOptions( + dt=0.01, + requires_grad=True, + ), + rigid_options=gs.options.RigidOptions( + # We use Newton's method because it converges faster than CG, and therefore gives better gradient estimation + # when using finite difference method + constraint_solver=gs.constraint_solver.Newton, + ), + show_viewer=False, + ) + + scene.add_entity(gs.morphs.Plane(pos=(0, 0, 0))) + scene.add_entity(gs.morphs.Box(size=(1, 1, 1), pos=(10, 10, 0.49))) + franka = scene.add_entity( + gs.morphs.MJCF(file="xml/franka_emika_panda/panda.xml"), + ) + + scene.build() + rigid_solver = scene._sim.rigid_solver + constraint_solver = rigid_solver.constraint_solver + + franka.set_qpos([-1.0124, 1.5559, 1.3662, -1.6878, -1.5799, 1.7757, 1.4602, 0.04, 0.04]) + + # Monkeypatch the constraint resolve function to avoid overwriting the necessary information for computing gradients. + def constraint_solver_resolve(): + func_solve_init( + dofs_info=rigid_solver.dofs_info, + dofs_state=rigid_solver.dofs_state, + entities_info=rigid_solver.entities_info, + constraint_state=constraint_solver.constraint_state, + rigid_global_info=rigid_solver._rigid_global_info, + static_rigid_sim_config=rigid_solver._static_rigid_sim_config, + ) + func_solve_body( + entities_info=rigid_solver.entities_info, + dofs_info=rigid_solver.dofs_info, + dofs_state=rigid_solver.dofs_state, + constraint_state=constraint_solver.constraint_state, + rigid_global_info=rigid_solver._rigid_global_info, + static_rigid_sim_config=rigid_solver._static_rigid_sim_config, + _n_iterations=constraint_solver._n_iterations, + ) + + monkeypatch.setattr(constraint_solver, "resolve", constraint_solver_resolve) + + # Step once to compute constraint solver's inputs: [mass], [jac], [aref], [efc_D], [force]. We do not call the + # entire scene.step() because it will overwrite the necessary information that we need to compute the gradients. + kernel_step_1( + links_state=rigid_solver.links_state, + links_info=rigid_solver.links_info, + joints_state=rigid_solver.joints_state, + joints_info=rigid_solver.joints_info, + dofs_state=rigid_solver.dofs_state, + dofs_info=rigid_solver.dofs_info, + geoms_state=rigid_solver.geoms_state, + geoms_info=rigid_solver.geoms_info, + entities_state=rigid_solver.entities_state, + entities_info=rigid_solver.entities_info, + rigid_global_info=rigid_solver._rigid_global_info, + static_rigid_sim_config=rigid_solver._static_rigid_sim_config, + contact_island_state=constraint_solver.contact_island.contact_island_state, + is_forward_pos_updated=True, + is_forward_vel_updated=True, + is_backward=False, + ) + constraint_solver.add_equality_constraints() + rigid_solver.collider.detection() + constraint_solver.add_inequality_constraints() + constraint_solver.resolve() + + # Loss function to compute gradients using finite difference method + def compute_loss(input_mass, input_jac, input_aref, input_efc_D, input_force): + rigid_solver._rigid_global_info.mass_mat.from_numpy(input_mass) + constraint_solver.constraint_state.jac.from_numpy(input_jac) + constraint_solver.constraint_state.aref.from_numpy(input_aref) + constraint_solver.constraint_state.efc_D.from_numpy(input_efc_D) + rigid_solver.dofs_state.force.from_numpy(input_force) + + # Recompute acc_smooth from the updated input variables + updated_acc_smooth = np.linalg.solve(input_mass[..., 0], input_force[..., 0]) + rigid_solver.dofs_state.acc_smooth.from_numpy(updated_acc_smooth[..., None]) + constraint_solver.resolve() + + output_qacc = qd_to_torch(constraint_solver.qacc) + return ((output_qacc - target_qacc) ** 2).mean() + + init_input_mass = qd_to_numpy(rigid_solver._rigid_global_info.mass_mat, copy=True) + init_input_jac = qd_to_numpy(constraint_solver.constraint_state.jac, copy=True) + init_input_aref = qd_to_numpy(constraint_solver.constraint_state.aref, copy=True) + init_input_efc_D = qd_to_numpy(constraint_solver.constraint_state.efc_D, copy=True) + init_input_force = qd_to_numpy(rigid_solver.dofs_state.force, copy=True) + + # Initial output of the constraint solver + set_random_seed(0) + init_output_qacc = qd_to_torch(constraint_solver.qacc) + target_qacc = torch.from_numpy(np.random.randn(*init_output_qacc.shape)).to(device=gs.device) + target_qacc = target_qacc * init_output_qacc.abs().mean() + + # Solve the constraint solver and get the output + output_qacc = qd_to_torch(constraint_solver.qacc, copy=True).requires_grad_(True) + + # Compute loss and gradient of the output + loss = ((output_qacc - target_qacc) ** 2).mean() + dL_dqacc = tensor_to_array(torch.autograd.grad(loss, output_qacc)[0]) + + # Compute gradients of the input variables: [mass], [jac], [aref], [efc_D], [force] + constraint_solver.constraint_state.dL_dqacc.from_numpy(dL_dqacc) + constraint_solver.backward() + + # Fetch gradients of the input variables + dL_dM = qd_to_numpy(constraint_solver.constraint_state.dL_dM) + dL_djac = qd_to_numpy(constraint_solver.constraint_state.dL_djac) + dL_daref = qd_to_numpy(constraint_solver.constraint_state.dL_daref) + dL_defc_D = qd_to_numpy(constraint_solver.constraint_state.dL_defc_D) + dL_dforce = qd_to_numpy(constraint_solver.constraint_state.dL_dforce) + + ### Compute directional derivatives along random directions + FD_EPS = 1e-3 + TRIALS = 200 + + for dL_dx, x_type in ( + (dL_dforce, "force"), + (dL_daref, "aref"), + (dL_defc_D, "efc_D"), + (dL_djac, "jac"), + (dL_dM, "mass"), + ): + dL_error = 0.0 + for _ in range(TRIALS): + rand_dx = np.random.randn(*dL_dx.shape) + rand_dx = rand_dx / max( + np.linalg.norm(rand_dx, axis=0 if x_type in ("force", "aref", "efc_D") else (0, 1)), gs.EPS + ) + if x_type == "mass": + # Make rand_dx symmetric + rand_dx = (rand_dx + np.moveaxis(rand_dx, 0, 1)) * 0.5 + + dL = (rand_dx * dL_dx).sum() + + input_force = init_input_force + input_aref = init_input_aref + input_efc_D = init_input_efc_D + input_jac = init_input_jac + input_mass = init_input_mass + + # 1 * eps + if x_type == "force": + input_force = init_input_force + rand_dx * FD_EPS + elif x_type == "aref": + input_aref = init_input_aref + rand_dx * FD_EPS + elif x_type == "efc_D": + input_efc_D = init_input_efc_D + rand_dx * FD_EPS + elif x_type == "jac": + input_jac = init_input_jac + rand_dx * FD_EPS + elif x_type == "mass": + input_mass = init_input_mass + rand_dx * FD_EPS + lossP1 = compute_loss(input_mass, input_jac, input_aref, input_efc_D, input_force) + + # -1 * eps + if x_type == "force": + input_force = init_input_force - rand_dx * FD_EPS + elif x_type == "aref": + input_aref = init_input_aref - rand_dx * FD_EPS + elif x_type == "efc_D": + input_efc_D = init_input_efc_D - rand_dx * FD_EPS + elif x_type == "jac": + input_jac = init_input_jac - rand_dx * FD_EPS + elif x_type == "mass": + input_mass = init_input_mass - rand_dx * FD_EPS + + lossP2 = compute_loss(input_mass, input_jac, input_aref, input_efc_D, input_force) + dL_fd = (lossP1 - lossP2) / (2 * FD_EPS) + + dL_error += (dL - dL_fd).abs() / max(abs(dL), abs(dL_fd), gs.EPS) + + dL_error /= TRIALS + assert_allclose(dL_error, 0.0, atol=RTOL) diff --git a/tests/test_grad_mpm.py b/tests/test_grad_mpm.py new file mode 100644 index 0000000000..6535a26967 --- /dev/null +++ b/tests/test_grad_mpm.py @@ -0,0 +1,101 @@ +"""Differentiable MPM + Tool coupling — gradient-flow sanity. + +Separate from the rigid grad suite (test_grad_fd / test_grad_optim / +test_grad_utils): this exercises the MPM solver and a Tool (rigid stick) +pushing an elastic MPM box toward a goal, and checks that `loss.backward()` +routes non-zero gradients to the per-step control inputs (and zero to the +unused final step). It is a coupling/backward smoke test, not an FD or +optimization check. +""" + +import pytest +import torch + +import genesis as gs + + +pytestmark = [ + pytest.mark.debug(False), +] + + +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +def test_differentiable_push(show_viewer): + HORIZON = 10 + + scene = gs.Scene( + sim_options=gs.options.SimOptions( + dt=2e-3, + substeps=10, + requires_grad=True, + ), + mpm_options=gs.options.MPMOptions( + lower_bound=(0.0, -1.0, 0.0), + upper_bound=(1.0, 1.0, 0.55), + ), + viewer_options=gs.options.ViewerOptions( + camera_pos=(2.5, -0.15, 2.42), + camera_lookat=(0.5, 0.5, 0.1), + ), + show_viewer=show_viewer, + ) + + plane = scene.add_entity( + gs.morphs.URDF( + file="urdf/plane/plane.urdf", + fixed=True, + ) + ) + stick = scene.add_entity( + morph=gs.morphs.Mesh( + file="meshes/stirrer.obj", + scale=0.6, + pos=(0.5, 0.5, 0.05), + euler=(90.0, 0.0, 0.0), + ), + material=gs.materials.Tool( + friction=8.0, + ), + ) + obj = scene.add_entity( + morph=gs.morphs.Box( + lower=(0.2, 0.1, 0.05), + upper=(0.4, 0.3, 0.15), + ), + material=gs.materials.MPM.Elastic( + rho=500, + ), + ) + scene.build(n_envs=2) + + init_pos = gs.tensor([[0.3, 0.1, 0.28], [0.3, 0.1, 0.5]], requires_grad=True) + stick.set_position(init_pos) + pos_obj_init = gs.tensor([0.3, 0.3, 0.1], requires_grad=True) + obj.set_position(pos_obj_init) + v_obj_init = gs.tensor([0.0, -1.0, 0.0], requires_grad=True) + obj.set_velocity(v_obj_init) + goal = gs.tensor([0.5, 0.8, 0.05]) + + loss = 0.0 + v_list = [] + for i in range(HORIZON): + v_i = gs.tensor([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0]], requires_grad=True) + stick.set_velocity(vel=v_i) + v_list.append(v_i) + + scene.step() + + if i == HORIZON // 2: + mpm_particles = scene.get_state().solvers_state[scene.solvers.index(scene.mpm_solver)] + loss += torch.pow(mpm_particles.pos[mpm_particles.active == 1] - goal, 2).sum() + + if i == HORIZON - 2: + state = obj.get_state() + loss += torch.pow(state.pos - goal, 2).sum() + loss.backward() + + # TODO: It would be great to compare the gradient to its analytical or numerical value. + for v_i in v_list[:-1]: + assert (v_i.grad.abs() > gs.EPS).any() + assert (v_list[-1].grad.abs() < gs.EPS).all() diff --git a/tests/test_grad_optim.py b/tests/test_grad_optim.py new file mode 100644 index 0000000000..7ae6059b46 --- /dev/null +++ b/tests/test_grad_optim.py @@ -0,0 +1,430 @@ +"""Integration-level optimization convergence for the differentiable rigid solver. + +Within the differentiable-rigid test suite, this is the only *end-to-end* check. +The others verify the gradient is **locally** correct (FD-vs-analytical at a +point); this one asks whether that gradient is **useful** — does plain Adam, +driven by the diff-mode backward over a multi-step horizon, actually converge to +a known answer? + +Two optimization targets on the cartpole (contact-free), each recovering a +final-state target produced by an identical rollout from known inputs: + 1. `test_diff_optim_init_vel_cartpole` — Adam on the initial `dofs_velocity`. + 2. `test_diff_optim_control_force_cartpole` — Adam on per-step `control_dofs_force` +Each asserts the per-env loss (a) drops by ≥2 orders of magnitude and (b) ends +below an absolute threshold — i.e. the backward yields an informative descent +direction over the horizon, not merely a locally correct gradient. + +Parametrized over precision ∈ {fp64, fp32} and n_envs ∈ {0 (single), 4 +(batched, per-env distinct target / init)}. fp32 uses looser thresholds for its +lower precision floor. +""" + +import os +import sys +from pathlib import Path + +import numpy as np +import pytest +import torch + +import genesis as gs + +from .utils import assert_allclose + + +_PRECISION_PARAMS = [ + pytest.param("64", marks=pytest.mark.precision("64"), id="fp64"), + pytest.param("32", marks=pytest.mark.precision("32"), id="fp32"), +] + +_N_ENVS_PARAMS = [ + pytest.param(0, id="single"), + pytest.param(4, id="batched"), +] + + +def _build_scene(*, requires_grad: bool, n_envs: int = 0, substeps: int = 1): + scene = gs.Scene( + sim_options=gs.options.SimOptions( + dt=0.01, + substeps=substeps, + gravity=(0.0, 0.0, -9.81), + requires_grad=requires_grad, + ), + rigid_options=gs.options.RigidOptions( + enable_collision=False, + enable_self_collision=False, + enable_joint_limit=False, + disable_constraint=True, + use_hibernation=False, + use_contact_island=False, + ), + show_viewer=False, + ) + robot = scene.add_entity(gs.morphs.MJCF(file="xml/cartpole.xml")) + scene.build(n_envs=n_envs) + return scene, robot + + +def _rigid_state(scene): + state = scene.get_state() + return state.solvers_state[scene.solvers.index(scene.rigid_solver)] + + +def _rollout(scene, robot, init_vel, n_steps): + """Apply `init_vel` via `set_dofs_velocity` (the @tracked setter) on a + fresh `scene.reset()`, step `n_steps` times, and return the post-rollout + (qpos, dofs_vel) tensors from the solver state.""" + scene.reset() + robot.set_dofs_velocity(init_vel) + for _ in range(n_steps): + scene.step() + s = _rigid_state(scene) + return s.qpos, s.dofs_vel + + +def _input_shape(n_dofs: int, n_envs: int): + return (n_dofs,) if n_envs == 0 else (n_envs, n_dofs) + + +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("precision", _PRECISION_PARAMS) +@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS) +def test_diff_optim_init_vel_cartpole(show_viewer, n_envs, precision, backend): + """Adam on `init_dofs_velocity` (2 params per env: slider + hinge) + recovers the per-env final-state target produced by an identical + rollout from a known `target_init_vel`. Verifies the diff-mode + backward yields an informative descent direction over an N=32 step + horizon, with batch independence when n_envs > 0.""" + N_STEPS = 32 + N_ITER = 200 + LR = 1e-2 + N_DOFS = 2 # cartpole: slider + hinge + B = n_envs if n_envs > 0 else 1 + + # Precision-specific tolerances. fp32 has ~7 significant digits, so a + # rollout-trained scalar loss can't get below ~1e-4 even at the + # optimum, and the rate of improvement plateaus earlier. + if precision == "64": + REL_REDUCTION = 1e-2 + ABS_THRESHOLD = 1e-4 + else: + REL_REDUCTION = 1e-1 + ABS_THRESHOLD = 1e-2 + + rng = np.random.default_rng(seed=11) + + # --- target trajectory (non-differentiable scene). Per-env distinct + # target_init_vel when n_envs > 0, so each env converges to its own + # answer. + target_init_vel_np = rng.normal(size=_input_shape(N_DOFS, n_envs)) * 0.5 + + scene_ref, robot_ref = _build_scene(requires_grad=False, n_envs=n_envs) + target_init_vel_t = gs.tensor(target_init_vel_np, dtype=gs.tc_float) + with torch.no_grad(): + target_qpos, target_vel = _rollout(scene_ref, robot_ref, target_init_vel_t, N_STEPS) + target_qpos = target_qpos.detach().clone() + target_vel = target_vel.detach().clone() + + # --- differentiable scene to optimize on. Per-env distinct init noise. + scene_opt, robot_opt = _build_scene(requires_grad=True, n_envs=n_envs) + init_offset = rng.normal(size=_input_shape(N_DOFS, n_envs)) * 0.3 + init_vel_np = target_init_vel_np + init_offset + init_vel = gs.tensor(init_vel_np, dtype=gs.tc_float, requires_grad=True) + + opt = torch.optim.Adam([init_vel], lr=LR) + loss_history = [] # list of (B,) per-env loss arrays + + for it in range(N_ITER): + opt.zero_grad(set_to_none=False) + pred_qpos, pred_vel = _rollout(scene_opt, robot_opt, init_vel, N_STEPS) + diff_pos = (pred_qpos - target_qpos).reshape(B, -1) + diff_vel = (pred_vel - target_vel).reshape(B, -1) + loss_per_env = (diff_pos**2).sum(dim=-1) + (diff_vel**2).sum(dim=-1) + loss = loss_per_env.sum() + loss_history.append(loss_per_env.detach().cpu().numpy().copy()) + loss.backward() + assert init_vel.grad is not None, f"iter {it}: init_vel.grad is None" + opt.step() + + history = np.asarray(loss_history) # (N_ITER, B) + initial = history[0] + final = history[-1] + + # Save per-env loss curves for visual inspection. + if not os.environ.get("GENESIS_DIFF_OPTIM_NO_PLOT"): + try: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + backend_name = getattr(backend, "name", str(backend)) + tag = f"{backend_name}_{precision}_{'batched' if n_envs > 0 else 'single'}" + default_out = ( + Path(__file__).resolve().parent.parent / "runs" / "tmp" / f"diff_optim_init_vel_cartpole_{tag}.png" + ) + out_path = Path(os.environ.get("GENESIS_DIFF_OPTIM_PLOT_PATH", str(default_out))) + out_path.parent.mkdir(parents=True, exist_ok=True) + + cmap = plt.get_cmap("tab10") + fig, ax = plt.subplots(figsize=(7, 4)) + for b in range(B): + ax.plot(history[:, b], lw=1.2, color=cmap(b % 10), label=f"env{b}") + ax.set_yscale("log") + ax.set_xlabel("iteration") + ax.set_ylabel("loss (log scale)") + ax.set_title( + f"cartpole init_vel optim [{tag}]: " + f"init={initial.max():.2e} → final={final.max():.2e} " + f"(worst-env ratio {(final / initial).max():.2e})" + ) + ax.grid(True, which="both", alpha=0.3) + if B > 1: + ax.legend(loc="upper right", fontsize=8) + fig.tight_layout() + fig.savefig(str(out_path), dpi=120) + plt.close(fig) + print(f"\n[diff_optim] loss curve saved to {out_path}") + except ImportError: + pass + + # Per-env assertions: every env must satisfy both criteria. + rel_ratios = final / initial + worst_rel_env = int(np.argmax(rel_ratios)) + assert (rel_ratios < REL_REDUCTION).all(), ( + f"loss reduction insufficient (worst env={worst_rel_env}): " + f"initial={initial[worst_rel_env]:.3e}, final={final[worst_rel_env]:.3e}, " + f"ratio={rel_ratios[worst_rel_env]:.3e} (>= {REL_REDUCTION:.0e})" + ) + worst_abs_env = int(np.argmax(final)) + assert (final < ABS_THRESHOLD).all(), ( + f"final loss above absolute threshold (worst env={worst_abs_env}): " + f"{final[worst_abs_env]:.3e} >= {ABS_THRESHOLD:.0e}" + ) + + +def _save_loss_plot(history: np.ndarray, *, title_tag: str, plot_name: str): + """history: (N_ITER, B). Saves a per-env log-scale loss curve.""" + if os.environ.get("GENESIS_DIFF_OPTIM_NO_PLOT"): + return + try: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except ImportError: + return + + default_out = Path(__file__).resolve().parent.parent / "runs" / "tmp" / f"{plot_name}.png" + out_path = Path(os.environ.get("GENESIS_DIFF_OPTIM_PLOT_PATH", str(default_out))) + out_path.parent.mkdir(parents=True, exist_ok=True) + + initial = history[0] + final = history[-1] + B = history.shape[1] + cmap = plt.get_cmap("tab10") + fig, ax = plt.subplots(figsize=(7, 4)) + for b in range(B): + ax.plot(history[:, b], lw=1.2, color=cmap(b % 10), label=f"env{b}") + ax.set_yscale("log") + ax.set_xlabel("iteration") + ax.set_ylabel("loss (log scale)") + ax.set_title( + f"{title_tag}: init={initial.max():.2e} → final={final.max():.2e} " + f"(worst-env ratio {(final / initial).max():.2e})" + ) + ax.grid(True, which="both", alpha=0.3) + if B > 1: + ax.legend(loc="upper right", fontsize=8) + fig.tight_layout() + fig.savefig(str(out_path), dpi=120) + plt.close(fig) + print(f"\n[diff_optim] loss curve saved to {out_path}") + + +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("precision", _PRECISION_PARAMS) +@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS) +def test_diff_optim_control_force_cartpole(show_viewer, n_envs, precision, backend): + """Adam on per-step `control_dofs_force` (N_STEPS × n_dofs params per + env) recovers the per-env final-state target.""" + N_STEPS = 32 + N_ITER = 200 + LR = 1e-2 + N_DOFS = 2 + B = n_envs if n_envs > 0 else 1 + + if precision == "64": + REL_REDUCTION = 1e-2 + ABS_THRESHOLD = 1e-4 + else: + REL_REDUCTION = 1e-1 + ABS_THRESHOLD = 1e-2 + + rng = np.random.default_rng(seed=23) + + shape_per_step = _input_shape(N_DOFS, n_envs) + target_force_np = rng.normal(size=(N_STEPS,) + shape_per_step) * 0.2 + + # --- target trajectory (non-differentiable scene) --- + scene_ref, robot_ref = _build_scene(requires_grad=False, n_envs=n_envs) + with torch.no_grad(): + scene_ref.reset() + for t in range(N_STEPS): + robot_ref.control_dofs_force(gs.tensor(target_force_np[t], dtype=gs.tc_float)) + scene_ref.step() + s = _rigid_state(scene_ref) + target_qpos = s.qpos.detach().clone() + target_vel = s.dofs_vel.detach().clone() + + # --- differentiable scene + learnable per-step force tensors --- + scene_opt, robot_opt = _build_scene(requires_grad=True, n_envs=n_envs) + init_offset = rng.normal(size=(N_STEPS,) + shape_per_step) * 0.1 + init_force_np = target_force_np + init_offset + forces = [gs.tensor(init_force_np[t], dtype=gs.tc_float, requires_grad=True) for t in range(N_STEPS)] + optimizer = torch.optim.Adam(forces, lr=LR) + + loss_history = [] + for it in range(N_ITER): + optimizer.zero_grad(set_to_none=False) + scene_opt.reset() + for t in range(N_STEPS): + robot_opt.control_dofs_force(forces[t]) + scene_opt.step() + s = _rigid_state(scene_opt) + diff_pos = (s.qpos - target_qpos).reshape(B, -1) + diff_vel = (s.dofs_vel - target_vel).reshape(B, -1) + loss_per_env = (diff_pos**2).sum(dim=-1) + (diff_vel**2).sum(dim=-1) + loss = loss_per_env.sum() + loss_history.append(loss_per_env.detach().cpu().numpy().copy()) + loss.backward() + for t, f in enumerate(forces): + assert f.grad is not None, f"iter {it} step {t}: force.grad is None" + optimizer.step() + + history = np.asarray(loss_history) # (N_ITER, B) + initial = history[0] + final = history[-1] + + backend_name = getattr(backend, "name", str(backend)) + tag = f"{backend_name}_{precision}_{'batched' if n_envs > 0 else 'single'}" + _save_loss_plot( + history, + title_tag=f"cartpole control_force optim [{tag}]", + plot_name=f"diff_optim_control_force_cartpole_{tag}", + ) + + rel_ratios = final / initial + worst_rel_env = int(np.argmax(rel_ratios)) + assert (rel_ratios < REL_REDUCTION).all(), ( + f"loss reduction insufficient (worst env={worst_rel_env}): " + f"initial={initial[worst_rel_env]:.3e}, final={final[worst_rel_env]:.3e}, " + f"ratio={rel_ratios[worst_rel_env]:.3e} (>= {REL_REDUCTION:.0e})" + ) + worst_abs_env = int(np.argmax(final)) + assert (final < ABS_THRESHOLD).all(), ( + f"final loss above absolute threshold (worst env={worst_abs_env}): " + f"{final[worst_abs_env]:.3e} >= {ABS_THRESHOLD:.0e}" + ) + + +# =========================================================================== +# Box pose recovery via Adam (rigid, full scene.step rollout) +# =========================================================================== + + +@pytest.mark.slow # ~250s +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +def test_differentiable_rigid(show_viewer): + dt = 1e-2 + horizon = 100 + substeps = 1 + goal_pos = gs.tensor([0.7, 1.0, 0.05]) + goal_quat = gs.tensor([0.3, 0.2, 0.1, 0.9]) + goal_quat = goal_quat / torch.norm(goal_quat, dim=-1, keepdim=True) + + scene = gs.Scene( + sim_options=gs.options.SimOptions( + dt=dt, + substeps=substeps, + requires_grad=True, + gravity=(0, 0, -1), + ), + rigid_options=gs.options.RigidOptions( + enable_collision=False, + enable_self_collision=False, + enable_joint_limit=False, + disable_constraint=True, + use_contact_island=False, + use_hibernation=False, + ), + viewer_options=gs.options.ViewerOptions( + camera_pos=(2.5, -0.15, 2.42), + camera_lookat=(0.5, 0.5, 0.1), + ), + show_viewer=show_viewer, + ) + + box = scene.add_entity( + gs.morphs.Box( + pos=(0, 0, 0), + size=(0.1, 0.1, 0.2), + ), + surface=gs.surfaces.Default( + color=(0.9, 0.0, 0.0, 1.0), + ), + ) + if show_viewer: + target = scene.add_entity( + gs.morphs.Box( + pos=goal_pos, + quat=goal_quat, + size=(0.1, 0.1, 0.2), + ), + surface=gs.surfaces.Default( + color=(0.0, 0.9, 0.0, 0.5), + ), + ) + + scene.build() + + num_iter = 200 + lr = 1e-2 + + init_pos = gs.tensor([0.3, 0.1, 0.28], requires_grad=True) + init_quat = gs.tensor([1.0, 0.0, 0.0, 0.0], requires_grad=True) + optimizer = torch.optim.Adam([init_pos, init_quat], lr=lr) + + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iter, eta_min=1e-3) + + for _ in range(num_iter): + scene.reset() + + box.set_pos(init_pos) + box.set_quat(init_quat) + + loss = 0 + for _ in range(horizon): + scene.step() + if show_viewer: + target.set_pos(goal_pos) + target.set_quat(goal_quat) + + box_state = box.get_state() + box_pos = box_state.pos + box_quat = box_state.quat + loss = torch.abs(box_pos - goal_pos).sum() + torch.abs(box_quat - goal_quat).sum() + + optimizer.zero_grad() + loss.backward() # this lets gradient flow all the way back to tensor input + optimizer.step() + scheduler.step() + + with torch.no_grad(): + init_quat.data = init_quat / torch.norm(init_quat, dim=-1, keepdim=True) + + assert_allclose(loss, 0.0, atol=1e-2) diff --git a/tests/test_grad_utils.py b/tests/test_grad_utils.py new file mode 100644 index 0000000000..fe7806a06a --- /dev/null +++ b/tests/test_grad_utils.py @@ -0,0 +1,261 @@ +"""Unit tests for the `scene.backward(loss)` API. + +Within the differentiable-rigid test suite, this is the *plumbing* layer — +orthogonal to which physics is active. The other files check that the gradient +is numerically correct; this one checks that the backward *machinery* (state +snapshot/restore, gradient-tape clearing, no grad leak across chunked horizons) +is correct. + +`scene.backward(loss)` folds the snapshot → backward → restore dance +(`scene.get_state()` → `loss.backward()` → `scene.reset(snapshot)`) into a single +call; the flagship test below exercises it directly. + +The flagship test, ``test_horizon_truncation_matches_independent_scenes``, +runs three scenes in parallel: + + Scene A: single scene, 5-step horizon 1 → ``scene.backward(loss1)`` + (snapshot + backward + restore in one call) → 5-step horizon 2 + → ``scene.backward(loss2)``. Yields ``grad1_A`` and ``grad2_A``. + Scene B: same as A's horizon 1 only; ``scene.backward`` returns the + captured mid-trajectory snapshot. Yields ``grad1_B`` (compared + to ``grad1_A``) and that snapshot. + Scene C: fresh scene, starts from B's snapshot, runs 5-step horizon 2 → + ``scene.backward(loss2)``. Yields ``grad2_C`` (compared to ``grad2_A``). + +If `scene.backward(loss)` correctly (a) restores physics state, (b) clears +the gradient tape, and (c) doesn't leak grad accumulation across horizons, +then ``grad1_A == grad1_B`` and ``grad2_A == grad2_C`` exactly. +""" + +import sys + +import numpy as np +import pytest +import torch + +import genesis as gs +from genesis.utils.misc import qd_to_torch + +from .utils import assert_allclose + + +pytestmark = [ + pytest.mark.debug(False), +] + + +# Parametrization params (mirrors `test_grad_fd.py`). +_PRECISION_PARAMS = [ + pytest.param("64", marks=pytest.mark.precision("64"), id="fp64"), + pytest.param("32", marks=pytest.mark.precision("32"), id="fp32"), +] + +_N_ENVS_PARAMS = [ + pytest.param(0, id="single"), + pytest.param(4, id="batched"), +] + +_TOL = { + "64": dict(atol=1e-12, rtol=1e-10), + "32": dict(atol=1e-5, rtol=1e-4), +} + + +# J1~J5 joint topologies, loaded from the shared `xml/grad/` MJCF assets. +_TOPOLOGIES = [ + pytest.param("free", 6, id="J1_free"), + pytest.param("revolute", 1, id="J2_revolute"), + pytest.param("prismatic", 1, id="J3_prismatic"), + pytest.param("free_with_revolute", 7, id="J4_free_rev"), + pytest.param("revolute_chain3", 3, id="J5_chain3"), +] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_scene(model_name: str, n_envs: int = 0, substeps: int = 1): + """Build a diff-rigid scene with the standard "no collision / no constraint" config.""" + scene = gs.Scene( + sim_options=gs.options.SimOptions( + dt=0.01, + substeps=substeps, + gravity=(0.0, 0.0, 0.0), + requires_grad=True, + ), + rigid_options=gs.options.RigidOptions( + enable_collision=False, + enable_self_collision=False, + enable_joint_limit=False, + disable_constraint=True, + use_hibernation=False, + use_contact_island=False, + ), + show_viewer=False, + ) + robot = scene.add_entity(gs.morphs.MJCF(file=f"xml/grad/{model_name}.xml")) + scene.build(n_envs=n_envs) + return scene, robot + + +def _make_velocity(n_envs: int, n_dofs: int, seed: int) -> np.ndarray: + """Per-env-distinct velocity vector. Single env: shape (n_dofs,). Batched: (n_envs, n_dofs).""" + rng = np.random.default_rng(seed) + if n_envs == 0: + return rng.standard_normal(n_dofs) + return rng.standard_normal((n_envs, n_dofs)) + + +def _rigid_qpos_loss(scene): + """Differentiable scalar loss = sum((qpos)**2). Reads `state.qpos` via + `scene.get_state()` so the resulting tensor is a gs.Tensor whose + `.backward()` triggers `scene._backward()`.""" + state = scene.get_state() + rigid_state = state.solvers_state[scene.solvers.index(scene.rigid_solver)] + return (rigid_state.qpos**2).sum() + + +def _run_segment(scene, robot, v_tensor, n_steps: int): + """Apply `set_dofs_velocity(v_tensor)` once, then step `n_steps` times. + Returns the resulting (post-step) scalar loss.""" + robot.set_dofs_velocity(v_tensor) + for _ in range(n_steps): + scene.step() + return _rigid_qpos_loss(scene) + + +def _read_qpos(scene) -> np.ndarray: + """Read the simulator's current qpos field (detached).""" + solver = scene.rigid_solver + return qd_to_torch(solver._rigid_global_info.qpos, copy=True).cpu().numpy() + + +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +@pytest.mark.parametrize("precision_str", _PRECISION_PARAMS) +@pytest.mark.parametrize("substeps", [1, 4]) +@pytest.mark.parametrize("n_envs", _N_ENVS_PARAMS) +@pytest.mark.parametrize("model_name, n_dofs", _TOPOLOGIES) +def test_horizon_truncation_matches_independent_scenes(model_name, n_dofs, n_envs, substeps, precision_str): + """Two-segment trajectory in Scene A matches the same two + segments run in independent Scene B (horizon 1) and Scene C (horizon 2, + started from B's mid-trajectory snapshot via `scene.reset(state)`). + + Verifies that `scene.get_state()` + `scene.reset(state)` correctly + isolates two consecutive horizons: physics state propagates seamlessly, + but the autograd tapes are independent.""" + tol = _TOL[precision_str] + rng_v1 = _make_velocity(n_envs, n_dofs, seed=101) + rng_v2 = _make_velocity(n_envs, n_dofs, seed=202) + H = 5 + + # ----- Scene A: one scene, snapshot+reset between two horizons ----- + sceneA, robotA = _build_scene(model_name, n_envs=n_envs, substeps=substeps) + sceneA.reset() + v1A = gs.tensor(rng_v1, dtype=gs.tc_float, requires_grad=True) + loss_h1_A = _run_segment(sceneA, robotA, v1A, H) + qpos_mid_A = _read_qpos(sceneA) + # `scene.backward` snapshots the terminal state, runs the backward unroll, + # and restores that state — so horizon 2 continues seamlessly from here. + sceneA.backward(loss_h1_A) + # backward consumes the adstack / input buffer, so the step & substep + # counters reset to 0 (they index that buffer) — unlike the physics state, + # which is restored. Horizon 2 below thus records a fresh tape from 0. + assert sceneA._t == 0 and sceneA._sim._cur_substep_global == 0 + grad1_A = v1A.grad.detach().clone().cpu().numpy() + + v2A = gs.tensor(rng_v2, dtype=gs.tc_float, requires_grad=True) + loss_h2_A = _run_segment(sceneA, robotA, v2A, H) + qpos_end_A = _read_qpos(sceneA) + sceneA.backward(loss_h2_A) + grad2_A = v2A.grad.detach().clone().cpu().numpy() + + # ----- Scene B: same horizon 1 only ----- + sceneB, robotB = _build_scene(model_name, n_envs=n_envs, substeps=substeps) + sceneB.reset() + v1B = gs.tensor(rng_v1, dtype=gs.tc_float, requires_grad=True) + loss_h1_B = _run_segment(sceneB, robotB, v1B, H) + qpos_mid_B = _read_qpos(sceneB) + # `scene.backward` returns the terminal snapshot it captured; Scene C below + # loads it into a fresh scene via `reset(snapshot_B)`. + snapshot_B = sceneB.backward(loss_h1_B) + grad1_B = v1B.grad.detach().clone().cpu().numpy() + + # Sanity: A and B end at the same intermediate state and produce the same loss. + assert_allclose(qpos_mid_A, qpos_mid_B, atol=0, rtol=0) + assert_allclose(loss_h1_A.detach().cpu().item(), loss_h1_B.detach().cpu().item(), atol=0, rtol=0) + # Core assertion: horizon-1 gradient identical. + assert_allclose(grad1_A, grad1_B, **tol) + + # ----- Scene C: fresh scene, start from B's mid-trajectory snapshot ----- + sceneC, robotC = _build_scene(model_name, n_envs=n_envs, substeps=substeps) + sceneC.reset(snapshot_B) + v2C = gs.tensor(rng_v2, dtype=gs.tc_float, requires_grad=True) + loss_h2_C = _run_segment(sceneC, robotC, v2C, H) + qpos_end_C = _read_qpos(sceneC) + sceneC.backward(loss_h2_C) + grad2_C = v2C.grad.detach().clone().cpu().numpy() + + # Sanity: A and C end at the same final state and produce the same loss. + assert_allclose(qpos_end_A, qpos_end_C, atol=0, rtol=0) + assert_allclose(loss_h2_A.detach().cpu().item(), loss_h2_C.detach().cpu().item(), atol=0, rtol=0) + # Core assertion: horizon-2 gradient identical. + assert_allclose(grad2_A, grad2_C, **tol) + + +# =========================================================================== +# sim-state vs solver-state gradient parity +# =========================================================================== + + +@pytest.mark.required +@pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) +def test_diff_sim_vs_solver_state_grad_parity(show_viewer): + scene = gs.Scene( + sim_options=gs.options.SimOptions( + dt=0.01, + gravity=(0.0, 0.0, 0.0), + requires_grad=True, + ), + rigid_options=gs.options.RigidOptions( + enable_collision=False, + ), + show_viewer=show_viewer, + ) + robot = scene.add_entity( + gs.morphs.Box( + size=(0.1, 0.1, 0.1), + pos=(0, 0, 0), + ) + ) + scene.build() + + ctrl = gs.tensor(np.random.randn(robot.n_dofs), dtype=gs.tc_float, requires_grad=True) + + grads = [] + for use_sim_state in (False, True): + scene.reset() + + robot.set_dofs_velocity(ctrl) + scene.step() + + if use_sim_state: + solver_state = scene.get_state().solvers_state[scene.solvers.index(scene.rigid_solver)] + chassis_pos = solver_state.links_pos[:, 0].squeeze() + else: + chassis_pos = robot.get_state().pos.squeeze() + + loss = torch.linalg.norm(chassis_pos) + loss.backward() + grad = ctrl.grad.detach().clone() + ctrl.grad.zero_() + + # Basic sanity check + assert (grad[..., :3].abs() > gs.EPS).all() + assert (grad[..., 3:].abs() < gs.EPS).all() + + grads.append(grad) + + assert_allclose(*grads, atol=gs.EPS)