diff --git a/genesis/engine/solvers/kinematic_solver.py b/genesis/engine/solvers/kinematic_solver.py index 92ae00be0b..98efb64d99 100644 --- a/genesis/engine/solvers/kinematic_solver.py +++ b/genesis/engine/solvers/kinematic_solver.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, Literal import numpy as np +import quadrants as qd import torch import genesis as gs @@ -220,6 +221,7 @@ def _build_static_config(self): # Static config with all physics disabled self._static_rigid_sim_config = array_class.RigidSimStaticConfig( backend=gs.backend, + is_qd_float_f32=(gs.qd_float == qd.f32), para_level=self.sim._para_level, requires_grad=False, use_hibernation=False, diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py index 041fb5a5eb..720e5e8471 100644 --- a/genesis/engine/solvers/rigid/constraint/solver.py +++ b/genesis/engine/solvers/rigid/constraint/solver.py @@ -1298,7 +1298,9 @@ def add_frictionloss_constraints( # if `serialize=True`... qd.loop_config( name="add_frictionloss_constraints", - serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL and gs.backend != gs.metal), + serialize=qd.static( + static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL and static_rigid_sim_config.backend != gs.metal + ), ) for i_b in range(_B): constraint_state.n_constraints_frictionloss[i_b] = 0 @@ -2119,7 +2121,9 @@ def func_cholesky_solve_tiled( # Performance is optimal for BLOCK_DIM = 64 BLOCK_DIM = qd.static(64) MAX_DOFS = qd.static(static_rigid_sim_config.tiled_n_dofs) - ENABLE_WARP_REDUCTION = qd.static(static_rigid_sim_config.backend == gs.cuda and gs.qd_float == qd.f32) + ENABLE_WARP_REDUCTION = qd.static( + static_rigid_sim_config.backend == gs.cuda and static_rigid_sim_config.is_qd_float_f32 + ) WARP_SIZE = qd.static(32) NUM_WARPS = qd.static(BLOCK_DIM // WARP_SIZE) diff --git a/genesis/engine/solvers/rigid/rigid_solver.py b/genesis/engine/solvers/rigid/rigid_solver.py index 42f2cf7d39..ccb26c749b 100644 --- a/genesis/engine/solvers/rigid/rigid_solver.py +++ b/genesis/engine/solvers/rigid/rigid_solver.py @@ -220,6 +220,7 @@ def _sanitize_sol_params( return sol_params +@qd.data_oriented(stable_members=True) class RigidSolver(KinematicSolver): # override typing _entities: list[RigidEntity] = gs.List() @@ -425,6 +426,7 @@ def _should_transpose_constraint_layout(self) -> bool: def _build_static_config(self): static_rigid_sim_config = dict( backend=gs.backend, + is_qd_float_f32=(gs.qd_float == qd.f32), para_level=self.sim._para_level, requires_grad=self.sim.options.requires_grad, use_hibernation=self._use_hibernation, @@ -960,6 +962,150 @@ def _init_constraint_solver(self): self.constraint_solver = ConstraintSolverIsland(self) else: self.constraint_solver = ConstraintSolver(self) + self._contact_island_state = self.constraint_solver.contact_island.contact_island_state + self._collider_state = self.collider._collider_state + + @qd.kernel(fastcache=True) + def step_1( + self, + is_forward_pos_updated: qd.template(), + is_forward_vel_updated: qd.template(), + is_backward: qd.template(), + ): + if qd.static(not is_forward_pos_updated): + func_update_cartesian_space( + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + is_backward=is_backward, + ) + + if qd.static(not is_forward_vel_updated): + func_forward_velocity( + entities_info=self.entities_info, + links_info=self.links_info, + links_state=self.links_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=is_backward, + ) + + func_forward_dynamics( + links_state=self.links_state, + links_info=self.links_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + joints_info=self.joints_info, + entities_state=self.entities_state, + entities_info=self.entities_info, + geoms_state=self.geoms_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + contact_island_state=self._contact_island_state, + is_backward=is_backward, + ) + + @qd.kernel(fastcache=True) + def step_2( + self, + is_backward: qd.template(), + ): + func_update_acc( + update_cacc=True, + dofs_state=self.dofs_state, + links_info=self.links_info, + links_state=self.links_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=is_backward, + ) + + if qd.static(self._static_rigid_sim_config.integrator != gs.integrator.approximate_implicitfast): + func_implicit_damping( + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=is_backward, + ) + + func_integrate( + dofs_state=self.dofs_state, + links_info=self.links_info, + joints_info=self.joints_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=is_backward, + ) + + if qd.static(self._static_rigid_sim_config.use_hibernation): + func_hibernate__for_all_awake_islands_either_hiberanate_or_update_aabb_sort_buffer( + dofs_state=self.dofs_state, + entities_state=self.entities_state, + entities_info=self.entities_info, + links_state=self.links_state, + geoms_state=self.geoms_state, + collider_state=self._collider_state, + unused__rigid_global_info=self._rigid_global_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + contact_island_state=self._contact_island_state, + errno=self._errno, + ) + func_aggregate_awake_entities( + entities_state=self.entities_state, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) + + if qd.static(not is_backward): + func_copy_next_to_curr( + dofs_state=self.dofs_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + errno=self._errno, + ) + + if qd.static(not self._static_rigid_sim_config.enable_mujoco_compatibility): + func_update_cartesian_space( + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + force_update_fixed_geoms=False, + is_backward=is_backward, + ) + func_forward_velocity( + entities_info=self.entities_info, + links_info=self.links_info, + links_state=self.links_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + is_backward=is_backward, + ) def substep(self, f): # from genesis.utils.tools import create_timer @@ -974,20 +1120,7 @@ def substep(self, f): static_rigid_sim_config=self._static_rigid_sim_config, ) - kernel_step_1( - self.links_state, - self.links_info, - self.joints_state, - self.joints_info, - self.dofs_state, - self.dofs_info, - self.geoms_state, - self.geoms_info, - self.entities_state, - self.entities_info, - self._rigid_global_info, - self._static_rigid_sim_config, - self.constraint_solver.contact_island.contact_island_state, + self.step_1( self._is_forward_pos_updated, self._is_forward_vel_updated, self._is_backward, @@ -1002,23 +1135,8 @@ def substep(self, f): ) else: self._func_constraint_force() - kernel_step_2( - self.dofs_state, - self.dofs_info, - self.links_info, - self.links_state, - self.joints_info, - self.joints_state, - self.entities_state, - self.entities_info, - self.geoms_info, - self.geoms_state, - self.collider._collider_state, - self._rigid_global_info, - self._static_rigid_sim_config, - self.constraint_solver.contact_island.contact_island_state, + self.step_2( self._is_backward, - self._errno, ) self._is_forward_pos_updated = not self._enable_mujoco_compatibility self._is_forward_vel_updated = not self._enable_mujoco_compatibility @@ -1352,24 +1470,7 @@ def substep_pre_coupling_grad(self, f): if not is_grad_valid: gs.raise_exception(f"Nan grad in qpos or dofs_vel found at step {self._sim.cur_step_global}") - kernel_step_2.grad( - dofs_state=self.dofs_state, - dofs_info=self.dofs_info, - links_info=self.links_info, - links_state=self.links_state, - joints_info=self.joints_info, - joints_state=self.joints_state, - entities_state=self.entities_state, - entities_info=self.entities_info, - geoms_info=self.geoms_info, - geoms_state=self.geoms_state, - collider_state=self.collider._collider_state, - rigid_global_info=self._rigid_global_info, - static_rigid_sim_config=self._static_rigid_sim_config, - contact_island_state=self.constraint_solver.contact_island.contact_island_state, - is_backward=True, - errno=self._errno, - ) + self.step_2.grad(is_backward=True) # We cannot use [kernel_forward_dynamics.grad] because we read [dofs_state.acc] and overwrite it in the kernel, # which is prohibited (https://docs.taichi-lang.org/docs/differentiable_programming#global-data-access-rules). @@ -1454,23 +1555,8 @@ def substep_post_coupling(self, f): static_rigid_sim_config=self._static_rigid_sim_config, is_backward=self._is_backward, ) - kernel_step_2( - dofs_state=self.dofs_state, - dofs_info=self.dofs_info, - links_info=self.links_info, - links_state=self.links_state, - joints_info=self.joints_info, - joints_state=self.joints_state, - entities_state=self.entities_state, - entities_info=self.entities_info, - geoms_info=self.geoms_info, - geoms_state=self.geoms_state, - collider_state=self.collider._collider_state, - rigid_global_info=self._rigid_global_info, - static_rigid_sim_config=self._static_rigid_sim_config, - contact_island_state=self.constraint_solver.contact_island.contact_island_state, - is_backward=self._is_backward, - errno=self._errno, + self.step_2( + self._is_backward, ) elif isinstance(self.sim.coupler, IPCCoupler): # If any rigid entity is coupled to IPC, perform rigid simulation in post-coupling phase. @@ -2781,178 +2867,3 @@ def equalities(self): if self.is_built: return self._equalities return gs.List(equality for entity in self._entities for equality in entity.equalities) - - -@qd.kernel(fastcache=True) -def kernel_step_1( - links_state: array_class.LinksState, - links_info: array_class.LinksInfo, - joints_state: array_class.JointsState, - joints_info: array_class.JointsInfo, - dofs_state: array_class.DofsState, - dofs_info: array_class.DofsInfo, - geoms_state: array_class.GeomsState, - geoms_info: array_class.GeomsInfo, - entities_state: array_class.EntitiesState, - entities_info: array_class.EntitiesInfo, - rigid_global_info: array_class.RigidGlobalInfo, - static_rigid_sim_config: qd.template(), - contact_island_state: array_class.ContactIslandState, - is_forward_pos_updated: qd.template(), - is_forward_vel_updated: qd.template(), - is_backward: qd.template(), -): - if qd.static(not is_forward_pos_updated): - func_update_cartesian_space( - links_state=links_state, - links_info=links_info, - joints_state=joints_state, - joints_info=joints_info, - dofs_state=dofs_state, - dofs_info=dofs_info, - geoms_state=geoms_state, - geoms_info=geoms_info, - entities_info=entities_info, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - force_update_fixed_geoms=False, - is_backward=is_backward, - ) - - if qd.static(not is_forward_vel_updated): - func_forward_velocity( - entities_info=entities_info, - links_info=links_info, - links_state=links_state, - joints_info=joints_info, - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, - ) - - func_forward_dynamics( - links_state=links_state, - links_info=links_info, - dofs_state=dofs_state, - dofs_info=dofs_info, - joints_info=joints_info, - entities_state=entities_state, - entities_info=entities_info, - geoms_state=geoms_state, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - contact_island_state=contact_island_state, - is_backward=is_backward, - ) - - -@qd.kernel(fastcache=True) -def kernel_step_2( - dofs_state: array_class.DofsState, - dofs_info: array_class.DofsInfo, - links_info: array_class.LinksInfo, - links_state: array_class.LinksState, - joints_info: array_class.JointsInfo, - joints_state: array_class.JointsState, - entities_state: array_class.EntitiesState, - entities_info: array_class.EntitiesInfo, - geoms_info: array_class.GeomsInfo, - geoms_state: array_class.GeomsState, - collider_state: array_class.ColliderState, - rigid_global_info: array_class.RigidGlobalInfo, - static_rigid_sim_config: qd.template(), - contact_island_state: array_class.ContactIslandState, - is_backward: qd.template(), - errno: qd.Tensor, -): - # Position, Velocity and Acceleration data must be consistent when computing links acceleration, otherwise it - # would not corresponds to anyting physical. There is no other way than doing this right before integration, - # because the acceleration at the end of the step is unknown for now as it may change discontinuous between - # before and after integration under the effect of external forces and constraints. This means that - # acceleration data will be shifted one timestep in the past, but there isn't really any way around. - func_update_acc( - update_cacc=True, - dofs_state=dofs_state, - links_info=links_info, - links_state=links_state, - entities_info=entities_info, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, - ) - - if qd.static(static_rigid_sim_config.integrator != gs.integrator.approximate_implicitfast): - func_implicit_damping( - dofs_state=dofs_state, - dofs_info=dofs_info, - entities_info=entities_info, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, - ) - - func_integrate( - dofs_state=dofs_state, - links_info=links_info, - joints_info=joints_info, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, - ) - - if qd.static(static_rigid_sim_config.use_hibernation): - func_hibernate__for_all_awake_islands_either_hiberanate_or_update_aabb_sort_buffer( - dofs_state=dofs_state, - entities_state=entities_state, - entities_info=entities_info, - links_state=links_state, - geoms_state=geoms_state, - collider_state=collider_state, - unused__rigid_global_info=rigid_global_info, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - contact_island_state=contact_island_state, - errno=errno, - ) - func_aggregate_awake_entities( - entities_state=entities_state, - entities_info=entities_info, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - ) - - if qd.static(not is_backward): - func_copy_next_to_curr( - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - errno=errno, - ) - - if qd.static(not static_rigid_sim_config.enable_mujoco_compatibility): - func_update_cartesian_space( - links_state=links_state, - links_info=links_info, - joints_state=joints_state, - joints_info=joints_info, - dofs_state=dofs_state, - dofs_info=dofs_info, - geoms_state=geoms_state, - geoms_info=geoms_info, - entities_info=entities_info, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - force_update_fixed_geoms=False, - is_backward=is_backward, - ) - func_forward_velocity( - entities_info=entities_info, - links_info=links_info, - links_state=links_state, - joints_info=joints_info, - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - is_backward=is_backward, - ) diff --git a/genesis/utils/array_class.py b/genesis/utils/array_class.py index f5ce99bcb5..c51b43a44d 100644 --- a/genesis/utils/array_class.py +++ b/genesis/utils/array_class.py @@ -2051,6 +2051,10 @@ def get_rigid_adjoint_cache(solver): @qd.data_oriented class RigidSimStaticConfig(metaclass=AutoInitMeta): backend: int + # Whether ``gs.qd_float == qd.f32`` (i.e. single-precision floats). Declared on the static config + # so kernels can read it via ``qd.static(static_rigid_sim_config.is_qd_float_f32)`` instead of + # the module-level ``gs.qd_float``, keeping the fastcache key sensitive to precision. + is_qd_float_f32: bool para_level: int enable_collision: bool use_hibernation: bool diff --git a/pyproject.toml b/pyproject.toml index 2d30e22377..43ad821683 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ readme = "README.md" requires-python = ">=3.10,<3.14" dependencies = [ "psutil", - "quadrants==0.8.0", + "quadrants==1.0.1b2", "pydantic>=2.11.0", "numpy>=1.26.4", "frozendict", diff --git a/tests/test_grad.py b/tests/test_grad.py index d6276fd6e8..72a1560836 100644 --- a/tests/test_grad.py +++ b/tests/test_grad.py @@ -217,7 +217,6 @@ def compute_dL_error(dL_dx, x_type): @pytest.mark.parametrize("backend", [gs.cpu, gs.gpu]) def test_diff_solver(monkeypatch): from genesis.engine.solvers.rigid.constraint.solver import func_solve_init, func_solve_body - from genesis.engine.solvers.rigid.rigid_solver import kernel_step_1 RTOL = 1e-4 @@ -270,20 +269,7 @@ def constraint_solver_resolve(): # Step once to compute constraint solver's inputs: [mass], [jac], [aref], [efc_D], [force]. We do not call the # entire scene.step() because it will overwrite the necessary information that we need to compute the gradients. - kernel_step_1( - links_state=rigid_solver.links_state, - links_info=rigid_solver.links_info, - joints_state=rigid_solver.joints_state, - joints_info=rigid_solver.joints_info, - dofs_state=rigid_solver.dofs_state, - dofs_info=rigid_solver.dofs_info, - geoms_state=rigid_solver.geoms_state, - geoms_info=rigid_solver.geoms_info, - entities_state=rigid_solver.entities_state, - entities_info=rigid_solver.entities_info, - rigid_global_info=rigid_solver._rigid_global_info, - static_rigid_sim_config=rigid_solver._static_rigid_sim_config, - contact_island_state=constraint_solver.contact_island.contact_island_state, + rigid_solver.step_1( is_forward_pos_updated=True, is_forward_vel_updated=True, is_backward=False, diff --git a/tests/test_quadrants.py b/tests/test_quadrants.py index a953904c62..e45bc927a9 100644 --- a/tests/test_quadrants.py +++ b/tests/test_quadrants.py @@ -208,11 +208,11 @@ def gs_num_envs_child(args: list[str]): scene.rigid_solver.collider.detection() qd.sync() - from genesis.engine.solvers.rigid.rigid_solver import kernel_step_1 + from genesis.engine.solvers.rigid.rigid_solver import RigidSolver - assert kernel_step_1._primal.fe_ll_cache_observations.cache_hit == args.expected_fe_ll_cache_hit - assert kernel_step_1._primal.src_ll_cache_observations.cache_key_generated == args.expected_use_src_ll_cache - assert kernel_step_1._primal.src_ll_cache_observations.cache_loaded == args.expected_src_ll_cache_hit + assert RigidSolver.step_1._primal.fe_ll_cache_observations.cache_hit == args.expected_fe_ll_cache_hit + assert RigidSolver.step_1._primal.src_ll_cache_observations.cache_key_generated == args.expected_use_src_ll_cache + assert RigidSolver.step_1._primal.src_ll_cache_observations.cache_loaded == args.expected_src_ll_cache_hit sys.exit(RET_SUCCESS) @@ -308,10 +308,10 @@ def change_scene(args: list[str]): z = qpos.reshape((*qpos.shape[:-1], args.n_objs, 7))[..., 2] assert_allclose(z, 0.2, atol=1e-3) - from genesis.engine.solvers.rigid.rigid_solver import kernel_step_1 + from genesis.engine.solvers.rigid.rigid_solver import RigidSolver - assert kernel_step_1._primal.src_ll_cache_observations.cache_validated == args.expected_src_ll_cache_hit - assert kernel_step_1._primal.src_ll_cache_observations.cache_loaded == args.expected_src_ll_cache_hit + assert RigidSolver.step_1._primal.src_ll_cache_observations.cache_validated == args.expected_src_ll_cache_hit + assert RigidSolver.step_1._primal.src_ll_cache_observations.cache_loaded == args.expected_src_ll_cache_hit sys.exit(RET_SUCCESS)