Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions genesis/assets/xml/cartpole.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
<!-- Cart-pole model: a cart on a slide joint along x carrying a pole on a hinge joint about y. -->
<mujoco model="cartpole">
<option gravity="0 0 -9.81"/>
<worldbody>
<body name="cart" pos="0 0 0">
<!-- Cart translation along the x axis; declares range -4 4 and zero damping. -->
<joint name="slider" type="slide" axis="1 0 0" range="-4 4" damping="0.0"/>
<!-- Explicit inertia: cart mass 1.0 located at the body origin. -->
<inertial pos="0 0 0" mass="1.0" diaginertia="1.0 1.0 1.0"/>
<!-- contype=0 and conaffinity=0 clear both collision bitmasks of this geom. -->
<geom name="cart_geom" type="box" size="0.25 0.25 0.1" contype="0" conaffinity="0" rgba="0 0 0.8 1"/>
<body name="pole" pos="0 0 0">
<!-- Undamped pole rotation about the y axis; no range is declared. -->
<joint name="hinge" type="hinge" axis="0 1 0" damping="0.0"/>
<!-- Pole mass 10.0 (ten times the cart mass) placed at z=0.5, the same position as pole_geom. -->
<inertial pos="0 0 0.5" mass="10.0" diaginertia="1.0 1.0 1.0"/>
<!-- Box with half-size 0.5 along z, so the pole spans z=0 to z=1 in the pole frame; collision bitmasks cleared. -->
<geom name="pole_geom" type="box" pos="0 0 0.5" size="0.025 0.025 0.5" contype="0" conaffinity="0" rgba="1 1 1 1"/>
</body>
</body>
</worldbody>
</mujoco>
9 changes: 9 additions & 0 deletions genesis/assets/xml/grad/capsule.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<!-- Single capsule body attached to the world through a free joint. -->
<mujoco model="capsule">
<!-- Angles in this file are interpreted in degrees. -->
<compiler angle="degree"/>
<worldbody>
<body name="capsule" pos="0 0 0">
<!-- No inertial element and no contype/conaffinity override are given for this body. -->
<!-- size lists two numbers; per MuJoCo's capsule convention these are radius and half-length (TODO confirm). -->
<geom type="capsule" size="0.1 0.2"/>
<joint name="capsule_joint" type="free"/>
</body>
</worldbody>
</mujoco>
9 changes: 9 additions & 0 deletions genesis/assets/xml/grad/free.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<!-- Single box body ("chassis") attached to the world through a free joint. -->
<mujoco model="free">
<worldbody>
<body name="chassis" pos="0 0 0">
<freejoint/>
<!-- Explicit inertia: mass 1.0 at the body origin with isotropic diagonal inertia 0.1. -->
<inertial mass="1.0" pos="0 0 0" diaginertia="0.1 0.1 0.1"/>
<!-- contype=0 and conaffinity=0 clear both collision bitmasks of this geom. -->
<geom type="box" size="0.1 0.1 0.1" contype="0" conaffinity="0"/>
</body>
</worldbody>
</mujoco>
14 changes: 14 additions & 0 deletions genesis/assets/xml/grad/free_with_revolute.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
<!-- Free-floating box ("chassis") with one child link ("arm") attached by a hinge about y. -->
<!-- NOTE(review): the model attribute is "free_with_child" while the file is named free_with_revolute.xml. -->
<mujoco model="free_with_child">
<worldbody>
<body name="chassis" pos="0 0 0">
<freejoint/>
<!-- Explicit inertia: mass 1.0 at the body origin with isotropic diagonal inertia 0.1. -->
<inertial mass="1.0" pos="0 0 0" diaginertia="0.1 0.1 0.1"/>
<!-- contype=0 and conaffinity=0 clear both collision bitmasks of this geom. -->
<geom type="box" size="0.1 0.1 0.1" contype="0" conaffinity="0"/>
<!-- Child frame is offset 0.2 along x from the chassis origin. -->
<body name="arm" pos="0.2 0 0">
<joint type="hinge" axis="0 1 0"/>
<!-- Mass 0.5 placed at x=0.1, the midpoint of the capsule segment below. -->
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<!-- Capsule running from the arm origin to x=0.2; collision bitmasks cleared. -->
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
</body>
</worldbody>
</mujoco>
9 changes: 9 additions & 0 deletions genesis/assets/xml/grad/prismatic.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<!-- Single box body ("slider") attached to the world by one slide joint along x. -->
<mujoco model="prismatic">
<worldbody>
<body name="slider" pos="0 0 0">
<!-- No range or damping is declared on this joint. -->
<joint type="slide" axis="1 0 0"/>
<!-- Explicit inertia: mass 0.5 at the body origin with isotropic diagonal inertia 0.01. -->
<inertial mass="0.5" pos="0 0 0" diaginertia="0.01 0.01 0.01"/>
<!-- contype=0 and conaffinity=0 clear both collision bitmasks of this geom. -->
<geom type="box" size="0.05 0.05 0.05" contype="0" conaffinity="0"/>
</body>
</worldbody>
</mujoco>
9 changes: 9 additions & 0 deletions genesis/assets/xml/grad/revolute.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<!-- Single link ("arm") attached to the world by one hinge joint about y. -->
<mujoco model="revolute">
<worldbody>
<body name="arm" pos="0 0 0">
<!-- No range or damping is declared on this joint. -->
<joint type="hinge" axis="0 1 0"/>
<!-- Mass 0.5 placed at x=0.1, the midpoint of the capsule segment below. -->
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<!-- Capsule running from the body origin to x=0.2; collision bitmasks cleared. -->
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
</worldbody>
</mujoco>
19 changes: 19 additions & 0 deletions genesis/assets/xml/grad/revolute_chain3.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
<!-- Serial chain of three links (l1, l2, l3), each attached to its parent by a hinge about y. -->
<!-- The three links share the same joint, inertial and geom definitions. -->
<mujoco model="chain3">
<worldbody>
<body name="l1" pos="0 0 0">
<joint type="hinge" axis="0 1 0"/>
<!-- Mass 0.3 placed at x=0.1, the midpoint of the link's capsule segment. -->
<inertial mass="0.3" pos="0.1 0 0" diaginertia="0.005 0.005 0.005"/>
<!-- Capsule from the link origin to x=0.2; collision bitmasks cleared. -->
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
<!-- Each child frame sits at x=0.2 in its parent, i.e. at the far end of the parent's capsule segment. -->
<body name="l2" pos="0.2 0 0">
<joint type="hinge" axis="0 1 0"/>
<inertial mass="0.3" pos="0.1 0 0" diaginertia="0.005 0.005 0.005"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
<body name="l3" pos="0.2 0 0">
<joint type="hinge" axis="0 1 0"/>
<inertial mass="0.3" pos="0.1 0 0" diaginertia="0.005 0.005 0.005"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
</body>
</body>
</worldbody>
</mujoco>
10 changes: 10 additions & 0 deletions genesis/assets/xml/grad/slider_limit.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
<!-- Single box body ("cart") on a slide joint along x, with gravity switched off. -->
<!-- NOTE(review): the model name suggests it exercises the joint range; confirm against the tests that load it. -->
<mujoco model="slider_limit">
<option gravity="0 0 0"/>
<worldbody>
<body name="cart" pos="0 0 0">
<!-- Translation along x; declares range -4 4 and zero damping. -->
<joint name="slider" type="slide" axis="1 0 0" range="-4 4" damping="0.0"/>
<!-- Explicit inertia: mass 1.0 at the body origin. -->
<inertial pos="0 0 0" mass="1.0" diaginertia="1.0 1.0 1.0"/>
<!-- contype=0 and conaffinity=0 clear both collision bitmasks of this geom. -->
<geom type="box" size="0.25 0.25 0.1" contype="0" conaffinity="0"/>
</body>
</worldbody>
</mujoco>
9 changes: 9 additions & 0 deletions genesis/assets/xml/grad/spherical.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<!-- Single link ("ball") attached to the world by one ball joint. -->
<mujoco model="spherical">
<worldbody>
<body name="ball" pos="0 0 0">
<joint type="ball"/>
<!-- Mass 0.5 placed at x=0.1, the midpoint of the capsule segment below. -->
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<!-- Capsule running from the body origin to x=0.2; collision bitmasks cleared. -->
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
</worldbody>
</mujoco>
27 changes: 27 additions & 0 deletions genesis/assets/xml/hopper.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
<!-- Planar hopper: torso, thigh, leg and foot connected in series by hinges about the y axis. -->
<!-- This file declares no actuator elements and no explicit inertial elements. -->
<mujoco model="hopper">
<!-- Joint ranges below are expressed in radians. -->
<compiler angle="radian"/>
<default>
<!-- Defaults applied to every joint and geom unless overridden on the element itself. -->
<joint limited="true" armature="1" damping="1"/>
<geom condim="3" friction="0.9 0.005 0.0001" rgba="0.8 0.6 0.4 1"/>
</default>
<worldbody>
<body name="torso" pos="0 0 1.25">
<!-- Root joints: slide along x, slide along z, hinge about y. Each overrides the defaults with limited=false, armature=0, damping=0. -->
<joint name="rootx" pos="0 0 0" axis="1 0 0" type="slide" limited="false" armature="0" damping="0"/>
<joint name="rootz" pos="0 0 0" axis="0 0 1" type="slide" limited="false" armature="0" damping="0"/>
<joint name="rooty" pos="0 0 0" axis="0 1 0" type="hinge" limited="false" armature="0" damping="0"/>
<geom name="torso_geom" type="capsule" size="0.05 0.2"/>
<body name="thigh" pos="0 0 -0.2">
<!-- Range -2.61799 to 0 rad, i.e. about -150 to 0 degrees. -->
<joint name="thigh_joint" pos="0 0 0" axis="0 -1 0" type="hinge" range="-2.61799 0"/>
<geom name="thigh_geom" type="capsule" size="0.05 0.225" pos="0 0 -0.225"/>
<body name="leg" pos="0 0 -0.7">
<!-- The joint anchor is offset to z=0.25 in the leg frame; range is about -150 to 0 degrees. -->
<joint name="leg_joint" pos="0 0 0.25" axis="0 -1 0" type="hinge" range="-2.61799 0"/>
<geom name="leg_geom" type="capsule" size="0.04 0.25"/>
<body name="foot" pos="0 0 -0.25">
<!-- Range -0.785398 to 0.785398 rad, i.e. about -45 to 45 degrees. -->
<joint name="foot_joint" pos="0 0 0" axis="0 -1 0" type="hinge" range="-0.785398 0.785398"/>
<!-- The quat is a 90 degree rotation about y, which lays the capsule's long axis along x. -->
<!-- friction here overrides the default sliding coefficient 0.9 with 2. -->
<geom name="foot_geom" type="capsule" size="0.06 0.195" pos="0.06 0 0" quat="0.707107 0 -0.707107 0" friction="2 0.005 0.0001"/>
</body>
</body>
</body>
</body>
</worldbody>
</mujoco>
19 changes: 18 additions & 1 deletion genesis/engine/entities/rigid_entity/rigid_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init__(
self._load_model()

# Initialize target variables and checkpoint
self._tgt_keys = ("pos", "quat", "qpos", "dofs_velocity")
self._tgt_keys = ("pos", "quat", "qpos", "dofs_velocity", "control_dofs_force")
self._tgt = dict()
self._tgt_buffer = list()
self._ckpt = dict()
Expand Down Expand Up @@ -1156,6 +1156,8 @@ def process_input(self, in_backward=False):
self.set_quat(**data_kwargs)
case "set_dofs_velocity":
self.set_dofs_velocity(**data_kwargs)
case "control_dofs_force":
self.control_dofs_force(**data_kwargs)
case _:
gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}")

Expand Down Expand Up @@ -1188,6 +1190,15 @@ def process_input_grad(self):
data_kwargs["dofs_idx_local"],
data_kwargs["envs_idx"],
)

case "control_dofs_force":
force = data_kwargs.pop("force")
if force is not None and force.requires_grad:
force._backward_from_qd(
self.set_dofs_force_grad,
data_kwargs["dofs_idx_local"],
data_kwargs["envs_idx"],
)
case _:
gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}")

Expand Down Expand Up @@ -3559,6 +3570,11 @@ def set_dofs_velocity_grad(self, dofs_idx_local, envs_idx, velocity_grad):
dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
self._solver.set_dofs_velocity_grad(dofs_idx, envs_idx, velocity_grad.data)

@gs.assert_built
def set_dofs_force_grad(self, dofs_idx_local, envs_idx, force_grad):
    """
    Forward a control-force gradient request for this entity's dofs to the solver.

    The entity-local dof indices are translated to solver-global indices, then the
    solver's `set_dofs_force_grad` is called with the raw `.data` of `force_grad`.

    Parameters
    ----------
    dofs_idx_local
        Dof indices local to this entity; translated with `_get_global_idx`.
    envs_idx
        Environment indices, passed through to the solver unchanged.
    force_grad
        Object exposing a `.data` attribute that is handed to the solver.
        NOTE(review): presumably the gradient tensor of the commanded force - confirm against caller.
    """
    global_dofs_idx = self._get_global_idx(
        dofs_idx_local,
        self.n_dofs,
        self._dof_start,
        unsafe=True,
    )
    grad_data = force_grad.data
    self._solver.set_dofs_force_grad(global_dofs_idx, envs_idx, grad_data)

# ------------------------------------------------------------------------------------
# ----------------------------- DOF property setters ---------------------------------
# ------------------------------------------------------------------------------------
Expand Down Expand Up @@ -3592,6 +3608,7 @@ def set_dofs_position(self, position, dofs_idx_local=None, envs_idx=None, *, zer
# ------------------------------------------------------------------------------------

@gs.assert_built
@tracked
def control_dofs_force(self, force, dofs_idx_local=None, envs_idx=None):
"""
Control the entity's dofs' motor force. This is used for force/torque control.
Expand Down
50 changes: 48 additions & 2 deletions genesis/engine/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,13 +977,16 @@ def reset(self, state: SimState | None = None, envs_idx=None):
self._reset(state, envs_idx=envs_idx)
self._recorder_manager.reset(envs_idx)

def _reset(self, state: SimState | None = None, *, envs_idx=None):
def _reset(self, state: SimState | None = None, *, envs_idx=None, keep_init: bool = False):
if self._is_built:
if state is None:
state = self._init_state
else:
assert isinstance(state, SimState), "state must be a SimState object"
self._init_state = state
# `keep_init=True` restores the state without making it the new
# init, so a later bare `reset()` still rewinds to the true init.
if not keep_init:
self._init_state = state
self._sim.reset(state, envs_idx)
else:
self._init_state = self._get_state()
Expand All @@ -1004,6 +1007,49 @@ def _reset(self, state: SimState | None = None, *, envs_idx=None):
# Mark the scene as ready for a backward pass by setting the `_backward_ready` flag to True.
def _reset_grad(self):
self._backward_ready = True

@gs.assert_built
def backward(self, loss: torch.Tensor, *args, **kwargs):
    """Differentiate `loss` and restore the terminal physics state.

    Wraps the snapshot/backward/restore dance that differentiable rollouts
    otherwise have to perform by hand. `scene._backward()` rewinds physics
    state to step 0 as a side-effect of unrolling the adstack, so the safe
    pattern is to snapshot the terminal state *before* backward and restore
    it *after*:

        snapshot = scene.get_state()   # terminal state
        loss.backward()                # rewinds physics to step 0
        scene.reset(snapshot)          # restore + clear grads + re-arm

    This method does exactly that, so callers can just write
    `scene.backward(loss)`. Afterwards the scene sits at the terminal physics
    state with grads cleared and forward/backward re-armed — ready to continue
    the rollout or to be reset to a fresh init.

    The registered initial state (`reset()` with no args) is left untouched.

    Parameters
    ----------
    loss : torch.Tensor
        Loss to differentiate.
    *args, **kwargs
        Forwarded to `torch.autograd.backward` after `loss`. Positionally that
        means `(grad_tensors, retain_graph, ...)`. The keyword `gradient` (the
        name used by `torch.Tensor.backward`) is accepted as an alias for
        `grad_tensors`. `retain_graph` defaults to True when the caller does
        not supply it.

    Returns
    -------
    SimState
        The terminal-state snapshot taken before differentiation, which is also
        the state the scene has been restored to.

    Raises
    ------
    TypeError
        If the upstream gradient is supplied more than once (`gradient`
        together with `grad_tensors` or with a positional argument).
    """
    # `torch.Tensor.backward` calls the upstream gradient `gradient`, while the functional
    # `torch.autograd.backward` used below calls it `grad_tensors`. Translate the alias so
    # that `scene.backward(loss, gradient=g)` does not fail with an unexpected-keyword error.
    if "gradient" in kwargs:
        if args or "grad_tensors" in kwargs:
            raise TypeError("backward() got multiple values for the upstream gradient of `loss`")
        kwargs["grad_tensors"] = kwargs.pop("gradient")

    # `scene._backward()` re-enters the torch graph from each step's queried
    # states (`_backward_from_qd` -> `state.backward(retain_graph=True)`), so
    # the graph must survive the initial autograd pass. Only inject the default
    # when `retain_graph` was not already given positionally (it is the second
    # positional argument after `loss`), otherwise torch would receive it twice.
    if len(args) < 2:
        kwargs.setdefault("retain_graph", True)

    # Snapshot the terminal state before backward rewinds physics to step 0.
    snapshot = self.get_state()
    # Functional `torch.autograd.backward` fills torch + queried-state grads
    # WITHOUT triggering `gs.Tensor.backward`'s auto `scene._backward()`, so
    # we drive the sim unroll explicitly below.
    torch.autograd.backward(loss, *args, **kwargs)
    self._backward()
    # Restore to the terminal snapshot; `keep_init=True` preserves the real
    # initial state so a later bare `reset()` still rewinds to it.
    self._reset(snapshot, keep_init=True)
    return snapshot

def _get_state(self):
return self._sim.get_state()

Expand Down
9 changes: 9 additions & 0 deletions genesis/engine/solvers/kinematic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
kernel_set_dofs_position,
kernel_set_dofs_velocity,
kernel_set_dofs_velocity_grad,
kernel_set_dofs_force_grad,
kernel_set_dofs_zero_velocity,
kernel_set_links_pos,
kernel_set_links_quat,
Expand Down Expand Up @@ -951,6 +952,14 @@ def set_dofs_velocity_grad(self, dofs_idx, envs_idx, velocity_grad):
velocity_grad_, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config
)

def set_dofs_force_grad(self, dofs_idx, envs_idx, force_grad):
    """
    Sanitize the inputs and launch `kernel_set_dofs_force_grad` on them.

    Parameters
    ----------
    dofs_idx
        Dof indices, sanitized against `self.n_dofs`.
    envs_idx
        Environment indices, sanitized alongside `dofs_idx`.
    force_grad
        Buffer passed to the kernel as its first argument. It is sanitized with
        `skip_allocation=True` and gains a leading dimension when `n_envs == 0`.
    """
    sanitized = self._sanitize_io_variables(
        force_grad, dofs_idx, self.n_dofs, "dofs_idx", envs_idx, skip_allocation=True
    )
    grad_buffer, dofs_idx, envs_idx = sanitized
    # Without batched environments the buffer has no env dimension; add one before the kernel call.
    if self.n_envs == 0:
        grad_buffer = grad_buffer.unsqueeze(0)
    kernel_set_dofs_force_grad(
        grad_buffer,
        dofs_idx,
        envs_idx,
        self.dofs_state,
        self._static_rigid_sim_config,
    )

def set_dofs_position(self, position, dofs_idx=None, envs_idx=None):
position, dofs_idx, envs_idx = self._sanitize_io_variables(
position, dofs_idx, self.n_dofs, "dofs_idx", envs_idx, skip_allocation=True
Expand Down
14 changes: 14 additions & 0 deletions genesis/engine/solvers/rigid/abd/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,20 @@ def kernel_set_dofs_zero_velocity(
dofs_state.vel[dofs_idx[i_d_], envs_idx[i_b_]] = 0.0


@qd.kernel(fastcache=True)
def kernel_set_dofs_force_grad(
    force_grad: qd.types.ndarray(),
    dofs_idx: qd.types.ndarray(),
    envs_idx: qd.types.ndarray(),
    dofs_state: array_class.DofsState,
    static_rigid_sim_config: qd.template(),
):
    # For every selected (dof, env) pair, copy the accumulated gradient of `ctrl_force` into
    # `force_grad` (indexed [env, dof] by position in the index arrays), then zero that gradient entry.
    qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)
    for i_d_, i_b_ in qd.ndrange(dofs_idx.shape[0], envs_idx.shape[0]):
        i_d = dofs_idx[i_d_]
        i_b = envs_idx[i_b_]
        force_grad[i_b_, i_d_] = dofs_state.ctrl_force.grad[i_d, i_b]
        dofs_state.ctrl_force.grad[i_d, i_b] = 0.0


@qd.kernel(fastcache=True)
def kernel_set_dofs_position(
position: qd.types.ndarray(),
Expand Down
53 changes: 42 additions & 11 deletions genesis/engine/solvers/rigid/abd/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import genesis as gs
import genesis.utils.geom as gu
import genesis.utils.array_class as array_class
from .forward_kinematics import func_update_cartesian_space
from .forward_kinematics import func_update_cartesian_space, func_forward_velocity


@qd.func
Expand Down Expand Up @@ -181,6 +181,16 @@ def kernel_prepare_backward_substep(
force_update_fixed_geoms=False,
is_backward=True,
)
func_forward_velocity(
entities_info=entities_info,
links_info=links_info,
links_state=links_state,
joints_info=joints_info,
dofs_state=dofs_state,
rigid_global_info=rigid_global_info,
static_rigid_sim_config=static_rigid_sim_config,
is_backward=True,
)

# FIXME: Parameter pruning for ndarray is buggy for now and requires match variable and arg names.
# Save results of [update_cartesian_space] to adjoint cache
Expand Down Expand Up @@ -232,17 +242,17 @@ def kernel_begin_backward_substep(
)

if not static_rigid_sim_config.enable_mujoco_compatibility:
# FIXME: Parameter pruning for ndarray is buggy for now and requires match variable and arg names.
# Save results of [update_cartesian_space] to adjoint cache
# Restore the cartesian space that was overwritten by
# post-integrate forward replay in the backward substep
func_copy_cartesian_space(
dofs_state=dofs_state,
links_state=links_state,
joints_state=joints_state,
geoms_state=geoms_state,
dofs_state_adjoint_cache=dofs_state_adjoint_cache,
links_state_adjoint_cache=links_state_adjoint_cache,
joints_state_adjoint_cache=joints_state_adjoint_cache,
geoms_state_adjoint_cache=geoms_state_adjoint_cache,
dofs_state=dofs_state_adjoint_cache,
links_state=links_state_adjoint_cache,
joints_state=joints_state_adjoint_cache,
geoms_state=geoms_state_adjoint_cache,
dofs_state_adjoint_cache=dofs_state,
links_state_adjoint_cache=links_state,
joints_state_adjoint_cache=joints_state,
geoms_state_adjoint_cache=geoms_state,
static_rigid_sim_config=static_rigid_sim_config,
)

Expand Down Expand Up @@ -347,6 +357,27 @@ def kernel_copy_acc(
dofs_state.acc[i_d, i_b] = rigid_adjoint_cache.dofs_acc[f, i_d, i_b]


@qd.kernel(fastcache=True)
def kernel_copy_next_to_curr_no_check(
    dofs_state: array_class.DofsState,
    rigid_global_info: array_class.RigidGlobalInfo,
    static_rigid_sim_config: qd.template(),
):
    # Overwrite the current `qpos` and `vel` with their `_next` counterparts for every entry,
    # with no per-entry condition. Used in the backward substep right before the forward replay
    # so the BW kernels see the post-integrate qpos / vel.
    num_envs = dofs_state.vel.shape[1]
    num_dofs = dofs_state.vel.shape[0]
    num_qs = rigid_global_info.qpos.shape[0]

    # Generalized positions: qpos <- qpos_next.
    qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
    for i_q, i_b in qd.ndrange(num_qs, num_envs):
        rigid_global_info.qpos[i_q, i_b] = rigid_global_info.qpos_next[i_q, i_b]

    # Dof velocities: vel <- vel_next.
    qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
    for i_d, i_b in qd.ndrange(num_dofs, num_envs):
        dofs_state.vel[i_d, i_b] = dofs_state.vel_next[i_d, i_b]


@qd.func
def func_integrate_dq_entity(
dq,
Expand Down
Loading
Loading