Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions genesis/engine/entities/rigid_entity/rigid_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,21 +1346,60 @@ def get_pos(self, envs_idx=None):
return self._solver.get_links_pos(self.base_link_idx, envs_idx)[..., 0, :]

@gs.assert_built
def get_quat(self, envs_idx=None):
def get_quat(self, envs_idx=None, *, relative=False):
"""
Returns quaternion of the entity's base link.

Parameters
----------
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
relative : bool, optional
If True, return the quaternion relative to the initial (not current!) quaternion.
The returned quaternion ``delta`` satisfies
``abs_quat == transform_quat_by_quat(init_quat, delta)``.
Equivalently, ``delta == transform_quat_by_quat(inv_quat(init_quat), abs_quat)``.
Defaults to False.

Returns
-------
quat : torch.Tensor, shape (4,) or (n_envs, 4)
The quaternion of the entity's base link.
"""
return self._solver.get_links_quat(self.base_link_idx, envs_idx)[..., 0, :]
The quaternion of the entity's base link (absolute or relative).
"""
abs_quat = self._solver.get_links_quat(self.base_link_idx, envs_idx)[..., 0, :]
if not relative:
return abs_quat

has_free_root_qpos = self.base_link.n_joints == 1 and self.base_link.joints[0].type == gs.JOINT_TYPE.FREE
if not has_free_root_qpos:
if self._solver._options.batch_links_info:
init_quat = qd_to_torch(
self._solver.links_info.quat,
envs_idx,
self.base_link_idx,
transpose=True,
copy=True,
)
if self._solver.n_envs == 0:
init_quat = init_quat[0, 0]
else:
init_quat = init_quat[:, 0]
else:
init_quat = torch.as_tensor(self.base_link.quat, dtype=abs_quat.dtype, device=abs_quat.device)
else:
q_start = self.base_link.q_start
init_quat = qd_to_torch(
self._solver.qpos0,
envs_idx,
slice(q_start + 3, q_start + 7),
transpose=True,
copy=True,
)
if self._solver.n_envs == 0:
init_quat = init_quat[0]

init_quat = init_quat.to(dtype=abs_quat.dtype, device=abs_quat.device)
return gu.transform_quat_by_quat(gu.inv_quat(init_quat), abs_quat)

@gs.assert_built
def get_vel(self, envs_idx=None):
Expand Down
83 changes: 83 additions & 0 deletions tests/test_rigid_physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1604,6 +1604,85 @@ def test_set_root_pose(batch_fixed_verts, relative, show_viewer, tol):
quat_ref = quat_delta
assert_allclose(quat, quat_ref, tol=tol)

if relative:
quat_rel_ref = quat_delta
else:
quat_rel_ref = gu.transform_quat_by_quat(gu.inv_quat(quat_zero), quat_delta)
assert_allclose(entity.get_quat(relative=True), quat_rel_ref, tol=tol)
# Verify get_quat(relative=False) matches get_quat() (preserves old behavior)
assert_allclose(entity.get_quat(relative=False), quat, tol=tol)


@pytest.mark.required
def test_get_quat_relative_heterogeneous_initial_quat(show_viewer, tol):
scene = gs.Scene(
rigid_options=gs.options.RigidOptions(batch_links_info=True),
show_viewer=show_viewer,
show_FPS=False,
)
box = scene.add_entity(
morph=(
gs.morphs.Box(size=(0.04, 0.04, 0.04), pos=(0.0, 0.0, 0.1), euler=(0.0, 0.0, 0.0)),
gs.morphs.Box(size=(0.04, 0.04, 0.04), pos=(0.0, 0.0, 0.1), euler=(0.0, 45.0, 0.0)),
),
)
scene.build(n_envs=4)

quat_delta = torch.tensor(
[
[0.9238795, 0.3826834, 0.0, 0.0],
[0.8660254, 0.0, 0.5, 0.0],
[0.7071068, 0.0, 0.0, 0.7071068],
[1.0, 0.0, 0.0, 0.0],
],
dtype=gs.tc_float,
device=gs.device,
)
quat_delta = quat_delta / torch.linalg.norm(quat_delta, dim=-1, keepdim=True)

box.set_quat(quat_delta, relative=True)

assert_allclose(box.get_quat(relative=True), quat_delta, tol=tol)
assert_allclose(box.get_quat(envs_idx=[2, 3], relative=True), quat_delta[2:], tol=tol)


@pytest.mark.required
def test_get_quat_relative_non_parallel(show_viewer, tol):
scene = gs.Scene(show_viewer=show_viewer, show_FPS=False)
box = scene.add_entity(gs.morphs.Box(size=(0.04, 0.04, 0.04), pos=(0.0, 0.0, 0.1), euler=(0.0, 30.0, 0.0)))
scene.build()

quat_delta = torch.tensor([0.9238795, 0.0, 0.3826834, 0.0], dtype=gs.tc_float, device=gs.device)
quat_delta = quat_delta / torch.linalg.norm(quat_delta)

box.set_quat(quat_delta, relative=True)
quat_rel = box.get_quat(relative=True)
assert quat_rel.shape == quat_delta.shape
assert_allclose(quat_rel, quat_delta, tol=tol)


@pytest.mark.required
def test_get_quat_relative_non_parallel_batched_link_info(show_viewer, tol):
scene = gs.Scene(
rigid_options=gs.options.RigidOptions(batch_links_info=True),
show_viewer=show_viewer,
show_FPS=False,
)
box = scene.add_entity(
gs.morphs.Box(
fixed=True,
batch_fixed_verts=True,
size=(0.04, 0.04, 0.04),
pos=(0.0, 0.0, 0.1),
euler=(0.0, 30.0, 0.0),
)
)
scene.build()

quat_rel = box.get_quat(relative=True)
assert quat_rel.shape == (4,)
assert_allclose(quat_rel, gu.identity_quat(), tol=tol)


@pytest.mark.required
def test_normalized_quat(show_viewer, tol):
Expand Down Expand Up @@ -5062,6 +5141,10 @@ def test_merge_entities(is_fixed, merge_fixed_links, show_viewer, tol, monkeypat

attach_link = franka.get_link("attachment")
assert_allclose(attach_link.get_pos(), hand.links[0].get_pos(), tol=gs.EPS)
hand_quat_rel = hand.get_quat(relative=True)
hand_init_quat = torch.as_tensor(hand.base_link.quat, dtype=gs.tc_float, device=gs.device)
hand_quat_rel_ref = gu.transform_quat_by_quat(gu.inv_quat(hand_init_quat), hand.get_quat())
assert_allclose(hand_quat_rel, hand_quat_rel_ref, tol=tol)
offset_quat = gu.transform_quat_by_quat(hand.links[0].get_quat(), gu.inv_quat(attach_link.get_quat()))
assert_allclose(gu.quat_to_xyz(offset_quat, rpy=False, degrees=True), EULER_OFFSET, tol=tol)
for link in hand.links[slice(0, None) if merge_fixed_links else slice(1, -1)]:
Expand Down