Skip to content
Merged
Show file tree
Hide file tree
Changes from 60 commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
4e7862c
primary impl
alanray-tech Nov 13, 2025
e5d699f
seems successful
alanray-tech Nov 14, 2025
dd257bd
able to load g1
alanray-tech Nov 19, 2025
97d4c8c
[FEATURE] Add USD stage import functionality and related parsing util…
alanray-tech Nov 20, 2025
b85d994
temp
alanray-tech Nov 25, 2025
0a44794
Merge branch 'Genesis-Embodied-AI:main' into main
alanray-tech Nov 25, 2025
8c88436
Merge remote-tracking branch 'origin/main' into dev
alanray-tech Nov 25, 2025
f652a3a
new structure
alanray-tech Nov 30, 2025
8814e38
add free root joint support
alanray-tech Dec 2, 2025
2e47e34
fix free joint base link init transform
alanray-tech Dec 3, 2025
b0727c8
refactor code to better structure
alanray-tech Dec 3, 2025
4d31b02
Merge branch 'main' into dev
alanray-tech Dec 3, 2025
9fc5db0
add missing usd-parser-related import
alanray-tech Dec 3, 2025
74689d2
support usd optional import
alanray-tech Dec 3, 2025
7a44ac2
add missing morphs.Drone
alanray-tech Dec 4, 2025
303f84f
support assets-download/argparser in usd example, clean up deps
alanray-tech Dec 8, 2025
e51f975
add usd import stage example to unit test
alanray-tech Dec 8, 2025
1f4633e
refactor parser
alanray-tech Dec 8, 2025
cabdd49
fix uv_name missing
alanray-tech Dec 8, 2025
7b0c544
fix rotation scaling extract
alanray-tech Dec 9, 2025
0b23d9a
add usd driver api support
alanray-tech Dec 10, 2025
0f5d298
update doc
alanray-tech Dec 10, 2025
606752d
update workflow, install usd for usd unit test
alanray-tech Dec 10, 2025
83675f1
add a simple animator, add default value init for dofs_frictionloss/d…
alanray-tech Dec 10, 2025
54443f0
add morph option
YilingQiao Dec 10, 2025
82b3533
Merge pull request #1 from YilingQiao/yiling/251210_usd_collision_vis…
alanray-tech Dec 11, 2025
7d64715
Merge remote-tracking branch 'pub/main' into dev
alanray-tech Dec 12, 2025
a3fee62
weird target behaviour, need fix
alanray-tech Dec 12, 2025
dbc08ca
Merge branch 'main' into dev
YilingQiao Dec 14, 2025
bb87147
set target
YilingQiao Dec 16, 2025
c7cfa53
update limit
YilingQiao Dec 17, 2025
8f02db1
add target to dofs info
YilingQiao Dec 23, 2025
886dfe6
Merge pull request #2 from YilingQiao/yiling/251216_change_target
alanray-tech Dec 23, 2025
7820f73
Merge remote-tracking branch 'pub/main' into dev
alanray-tech Dec 28, 2025
f9a037f
update pyproject.toml
alanray-tech Dec 29, 2025
e72dff4
merge origin/dev
alanray-tech Dec 29, 2025
28ba6ce
change damping
YilingQiao Dec 29, 2025
2290640
Merge pull request #3 from YilingQiao/yiling/251229_change_damping
alanray-tech Dec 29, 2025
9e24f49
Merge remote-tracking branch 'pub/main' into dev
alanray-tech Dec 30, 2025
0a8611d
fix rigid_solver_decomp missing entity_idx, which crash the rigid sim…
alanray-tech Dec 30, 2025
eba954a
Fixed the import path of .usda
alanray-tech Dec 30, 2025
a47bae5
fix usd_parser, move Entity type import into the function, so that a …
alanray-tech Dec 31, 2025
6a71764
try skip usd related test on ARM machine
alanray-tech Jan 3, 2026
45c5119
Merge branch 'main' into dev
alanray-tech Jan 4, 2026
fd4dda0
make usd import optional
alanray-tech Jan 4, 2026
e973715
Merge remote-tracking branch 'pub/main' into dev
alanray-tech Jan 5, 2026
e587af0
Merge branch 'main' into dev
alanray-tech Jan 6, 2026
ea45d00
Merge branch 'main' into dev
alanray-tech Jan 6, 2026
10a38a3
improve at api and code style level according to the review
alanray-tech Jan 8, 2026
feb0ebe
refactor parsing logic & clean up codes for better readability and pe…
alanray-tech Jan 9, 2026
3a82ffd
improve docstring
alanray-tech Jan 9, 2026
016909b
try fix CI
alanray-tech Jan 9, 2026
179be0e
fix drone test
alanray-tech Jan 9, 2026
9b0fbe2
unify the usd rigidbody and articulation parser to rigid entity parse…
alanray-tech Jan 10, 2026
c443129
remove compute_joint_axis_scaling_factor, because USD take this value…
alanray-tech Jan 10, 2026
dcea807
Merge branch 'main' into dev
alanray-tech Jan 10, 2026
0b54773
try fix NV EULA agreement input
alanray-tech Jan 11, 2026
e817bda
merge
alanray-tech Jan 11, 2026
5bab2c6
add some unit tests, not fully implemented
alanray-tech Jan 12, 2026
778b92b
add PureRigid/Revoluate/Prismatic/Spherical unit tests
alanray-tech Jan 13, 2026
d470b10
Adding OMNI_KIT_ACCEPT_EULA to the workflows that use USD; Improve ex…
alanray-tech Jan 13, 2026
c84b358
Merge branch 'dev' of https://github.com/alanray-tech/Genesis into dev
alanray-tech Jan 13, 2026
46a025c
use float32 compatible tol in test_usd
alanray-tech Jan 13, 2026
ff159d3
expose more options in USD(Morph), align them with MJCF, test passed
alanray-tech Jan 13, 2026
b111c67
clean up code, add better warning when attribute matching fails
alanray-tech Jan 14, 2026
f7bbd6a
Merge remote-tracking branch 'pub/main' into dev
alanray-tech Jan 14, 2026
86c5def
change warning about attribute missing to debug
alanray-tech Jan 14, 2026
131147c
include OMNI_KIT_ACCEPT_EULA and OMNI_KIT_ALLOW_ROOT in SLURM_ENV_VAR…
alanray-tech Jan 14, 2026
6ce56ec
support batched polar (torch and numpy), add filter function in polar…
alanray-tech Jan 14, 2026
38dd302
update tol
alanray-tech Jan 14, 2026
79a25ad
Merge remote-tracking branch 'pub/main' into dev
alanray-tech Jan 15, 2026
4c5f77c
Final cleanup.
duburcqa Jan 15, 2026
a4d64a7
Merge branch 'main' into dev
duburcqa Jan 15, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
run: |
pip install --upgrade pip setuptools wheel
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install -e '.[dev]' pynput
pip install -e '.[dev,usd]' pynput

- name: Get gstaichi version
id: gstaichi_version
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/generic.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ jobs:
shell: bash
run: |
PYTHON_DEPS="dev"
# Install USD for all platforms except ARM (usd-core doesn't support ARM)
# This is required for test_mesh.py which tests USD parsing functionality
if [[ "${{ matrix.OS }}" != 'ubuntu-24.04-arm' ]] ; then
PYTHON_DEPS="${PYTHON_DEPS},usd"
fi
Expand Down
102 changes: 102 additions & 0 deletions examples/usd/import_stage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import argparse
import os

import numpy as np
from huggingface_hub import snapshot_download

import genesis as gs
from genesis.utils.misc import ti_to_numpy


class JointAnimator:
"""
A simple JointAnimator to animate the joints' positions of the scene.
It uses the sin function to interpolate between the lower and upper limits of the joints.
"""

def __init__(self, scene: gs.Scene):
self.rigid = scene.sim.rigid_solver
joint_limits = ti_to_numpy(self.rigid.dofs_info.limit)
joint_limits = np.clip(joint_limits, -np.pi, np.pi)

init_positions = self.rigid.get_dofs_position().numpy()

self.joint_lower = joint_limits[:, 0]
self.joint_upper = joint_limits[:, 1]

valid_range_mask = (self.joint_upper - self.joint_lower) > gs.EPS

normalized_init_pos = np.where(
valid_range_mask,
2.0 * (init_positions - self.joint_lower) / (self.joint_upper - self.joint_lower) - 1.0,
0.0,
)
self.init_phase = np.arcsin(normalized_init_pos)

def animate(self, scene: gs.Scene):
"""Calculate target positions using sin function to interpolate between lower and upper limits"""
t = scene.t * scene.dt
theta = np.pi * t + self.init_phase
theta = theta % (2 * np.pi)
sin_values = np.sin(theta)
normalized = (sin_values + 1.0) / 2.0
target = self.joint_lower + (self.joint_upper - self.joint_lower) * normalized
self.rigid.control_dofs_position(target)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("-n", "--num_steps", type=int, default=1000)
parser.add_argument("-v", "--vis", action="store_true", default=False)
args = parser.parse_args()

args.num_steps = 1 if "PYTEST_VERSION" in os.environ else args.num_steps
args.vis = False if "PYTEST_VERSION" in os.environ else args.vis

gs.init(backend=gs.cpu)

dt = 0.002
scene = gs.Scene(
viewer_options=gs.options.ViewerOptions(
camera_pos=(3.5, 0.0, 2.5),
camera_lookat=(0.0, 0.0, 0.5),
camera_fov=40,
),
rigid_options=gs.options.RigidOptions(
dt=dt,
gravity=(0, 0, -9.8),
enable_collision=True,
enable_joint_limit=True,
max_collision_pairs=1000,
),
show_viewer=args.vis,
)

asset_path = snapshot_download(
repo_type="dataset",
repo_id="Genesis-Intelligence/assets",
revision="main",
allow_patterns="usd/Lightwheel_Kitchen001/Kitchen001/*",
max_workers=1,
)

entities = scene.add_stage(
morph=gs.morphs.USD(
file=f"{asset_path}/usd/Lightwheel_Kitchen001/Kitchen001/Kitchen001.usd",
),
# vis_mode="collision",
# visualize_contact=True,
)

scene.build()

joint_animator = JointAnimator(scene)

for _ in range(args.num_steps):
joint_animator.animate(scene)
scene.step()


if __name__ == "__main__":
main()
9 changes: 6 additions & 3 deletions genesis/engine/entities/rigid_entity/rigid_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _load_model(self):

if isinstance(self._morph, gs.morphs.Mesh):
self._load_mesh(self._morph, self._surface)
elif isinstance(self._morph, (gs.morphs.MJCF, gs.morphs.URDF, gs.morphs.Drone)):
elif isinstance(self._morph, (gs.morphs.MJCF, gs.morphs.URDF, gs.morphs.Drone, gs.morphs.USD)):
self._load_scene(self._morph, self._surface)
elif isinstance(self._morph, gs.morphs.Primitive):
self._load_primitive(self._morph, self._surface)
Expand Down Expand Up @@ -371,11 +371,10 @@ def _load_scene(self, morph, surface):
if isinstance(morph, gs.morphs.MJCF):
# Mujoco's unified MJCF+URDF parser systematically for MJCF files
l_infos, links_j_infos, links_g_infos, eqs_info = mju.parse_xml(morph, surface)
else:
elif isinstance(morph, (gs.morphs.URDF, gs.morphs.Drone)):
# Custom "legacy" URDF parser for loading geometries (visual and collision) and equality constraints.
# This is necessary because Mujoco cannot parse visual geometries (meshes) reliably for URDF.
l_infos, links_j_infos, links_g_infos, eqs_info = uu.parse_urdf(morph, surface)

# Mujoco's unified MJCF+URDF parser for only link, joints, and collision geometries properties.
morph_ = copy(morph)
morph_.visualization = False
Expand Down Expand Up @@ -416,7 +415,11 @@ def _load_scene(self, morph, surface):
link_g_infos.append(g_info)
except (ValueError, AssertionError):
gs.logger.info("Falling back to legacy URDF parser. Default values of physics properties may be off.")
elif isinstance(morph, gs.morphs.USD):
from genesis.utils.usd import parse_usd_rigid_entity

# Unified parser handles both articulations and rigid bodies
l_infos, links_j_infos, links_g_infos, eqs_info = parse_usd_rigid_entity(morph, surface)
# Add free floating joint at root if necessary
if (
(isinstance(morph, gs.morphs.Drone) or (isinstance(morph, gs.morphs.URDF) and not morph.fixed))
Expand Down
2 changes: 1 addition & 1 deletion genesis/engine/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def from_morph_surface(cls, morph, surface=None):
" and rotate glb mesh by default later and gradually enforce this option."
)
elif morph.is_format(gs.options.morphs.USD_FORMATS):
import genesis.utils.usda as usda_utils
import genesis.utils.usd.usda as usda_utils

meshes = usda_utils.parse_mesh_usd(morph.file, morph.group_by_material, morph.scale, surface)
elif isinstance(morph, gs.options.morphs.MeshSet):
Expand Down
13 changes: 12 additions & 1 deletion genesis/engine/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys
import time
import weakref
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Callable, Literal

import numpy as np
import torch
Expand Down Expand Up @@ -443,6 +443,17 @@ def add_entity(

return entity

@gs.assert_unbuilt
def add_stage(
self,
morph: gs.morphs.USD,
vis_mode: Literal["visual", "collision"] = "visual",
visualize_contact: bool = False,
):
from genesis.utils.usd import import_from_stage

return import_from_stage(self, morph.file, vis_mode, morph, visualize_contact)

@gs.assert_unbuilt
def add_mesh_light(
self,
Expand Down
70 changes: 70 additions & 0 deletions genesis/options/morphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1284,3 +1284,73 @@ def default_params(self):
@property
def subterrain_params(self):
return self._subterrain_parameters


class USD(FileMorph):
"""
Configuration class for USD file loading with advanced processing options.

This class encapsulates the file path and processing parameters for USD loading,
allowing users to control convexification, decimation, and decomposition behavior
when loading USD scenes via add_stage().

Parameters
----------
file : str
The path to the USD file.
convexify : bool, optional
Whether to convexify the entity. When convexify is True, all the meshes in the entity will each be converted
to a set of convex hulls. The mesh will be decomposed into multiple convex components if the convex hull is not
sufficient to meet the desired accuracy. The module 'coacd' is used for this decomposition process.
If not given, it defaults to `True` for `RigidEntity` and `False` for other deformable entities.
decompose_object_error_threshold : float, optional
For basic rigid objects (mug, table...), skip convex decomposition if the relative difference between the
volume of original mesh and its convex hull is lower than this threshold.
0.0 to enforce decomposition, float("inf") to disable it completely. Defaults to 0.15 (15%).
decompose_robot_error_threshold : float, optional
For poly-articulated robots, skip convex decomposition if the relative difference between the volume of
original mesh and its convex hull is lower than this threshold.
0.0 to enforce decomposition, float("inf") to disable it completely. Defaults to float("inf").
coacd_options : CoacdOptions, optional
Options for configuring coacd convex decomposition. Needs to be a `gs.options.CoacdOptions` object.
decimate : bool, optional
Whether to decimate (simplify) the mesh. Defaults to True. **This is only used for RigidEntity.**
decimate_face_num : int, optional
The number of faces to decimate to. Defaults to 500. **This is only used for RigidEntity.**
decimate_aggressiveness : int, optional
How hard the decimation process will try to match the target number of faces, as an integer ranging from 0 to 8.
0 is lossless. 2 preserves all features of the original geometry. 5 may significantly alter the original
geometry if necessary. 8 does what needs to be done at all costs. Defaults to 2.
**This is only used for RigidEntity.**
_prim_path : str, optional
The parsing target prim path. Defaults to None.
_parsing_type : str, optional
The parsing type.
'articulation' for articulated body parsing, ArticulationRootAPI is required.
'rigid_body' for rigid body parsing, CollisionAPI|RigidBodyAPI is required.
Defaults to None, no parsing will be performed.
"""

file: str
convexify: Optional[bool] = None
decompose_object_error_threshold: float = 0.15
decompose_robot_error_threshold: float = float("inf")
coacd_options: Optional[CoacdOptions] = None
decimate: bool = True
decimate_face_num: int = 500
decimate_aggressiveness: int = 2
prim_path: Optional[str] = None
parsing_type: Optional[Literal["articulation", "rigid_body"]] = None
parser_ctx: Any = None

def __init__(self, **data):
super().__init__(**data)

if not isinstance(self.file, str):
gs.raise_exception("`file` should be a string.")

if not self.file.lower().endswith(USD_FORMATS):
gs.raise_exception(f"USDMorph requires a USD file with extension {USD_FORMATS}, got: {self.file}")

if self.coacd_options is None:
self.coacd_options = CoacdOptions()
108 changes: 108 additions & 0 deletions genesis/utils/geom.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
from typing import Literal

import numpy as np
import numba as nb
Expand Down Expand Up @@ -1173,6 +1174,113 @@ def inv_transform_by_T(pos, T):
return transform_by_R(pos - trans, R_inv)


def polar(A, pure_rotation: bool = True, side: Literal["right", "left"] = "right"):
"""
Compute the polar decomposition of a matrix.

Parameters
----------
A : np.ndarray | torch.Tensor
The matrix to decompose. Must be a 2D matrix (M, N), not batched.
pure_rotation : bool, optional
If True, ensure the unitary matrix U has det(U) = 1 (pure rotation).
If False, U may have det(U) = -1 (contains reflection). Default is True.
side : Literal['right', 'left'], optional
The side of the decomposition. Either 'right' or 'left'. Default is 'right'.

Returns
-------
tuple[np.ndarray | torch.Tensor, np.ndarray | torch.Tensor]
A tuple of (U, P) where:
- U : The unitary matrix (rotation part), same shape as A (M, N).
- P : The positive semi-definite matrix (scaling part). For 'right' decomposition,
P has shape (N, N). For 'left' decomposition, P has shape (M, M).
"""
if isinstance(A, torch.Tensor):
if A.ndim != 2:
gs.raise_exception(f"Input must be a 2D matrix. got: {A.ndim=} dimensions")

U_svd, Sigma, Vt = torch.linalg.svd(A, full_matrices=False)

# Handle pure_rotation: if det(U) < 0, flip signs to make it a pure rotation
if pure_rotation and A.shape[0] == A.shape[1]: # Only for square matrices
# Compute U first to check its determinant
U_temp = U_svd @ Vt
det_U = torch.linalg.det(U_temp)
if det_U < 0:
U_svd = U_svd.clone()
Vt = Vt.clone()
Sigma = Sigma.clone()

# Flip the smallest singular value (last one) and corresponding column/row
# This maintains A = U_svd @ diag(Sigma) @ Vt while changing det(U)
if side == "right":
# Flip last singular value and last row of Vt
Sigma[-1] *= -1
Vt[-1, :] *= -1
else:
# For left polar, flip last singular value and last column of U_svd
Sigma[-1] *= -1
U_svd[:, -1] *= -1

U = U_svd @ Vt

# Construct diagonal matrix from singular values
# Use absolute value to ensure P is positive semi-definite
# (even if we flipped Sigma's sign to fix rotation)
Sigma_abs = torch.abs(Sigma)
Sigma_diag = torch.diag_embed(Sigma_abs)
if side == "right":
# P = Vt.T @ diag(|Sigma|) @ Vt
P = Vt.transpose(-2, -1) @ Sigma_diag @ Vt
else:
# P = U_svd @ diag(|Sigma|) @ U_svd.T (left polar: A = P @ U)
P = U_svd @ Sigma_diag @ U_svd.transpose(-2, -1)
return U, P
elif isinstance(A, np.ndarray):
if A.ndim != 2:
gs.raise_exception(f"Input must be a 2D matrix. got: {A.ndim=} dimensions")

U_svd, Sigma, Vt = np.linalg.svd(A, full_matrices=False)

# Handle pure_rotation: if det(U) < 0, flip signs to make it a pure rotation
if pure_rotation and A.shape[0] == A.shape[1]: # Only for square matrices
# Compute U first to check its determinant
U_temp = U_svd @ Vt
det_U = np.linalg.det(U_temp)
if det_U < 0:
U_svd = U_svd.copy()
Vt = Vt.copy()
Sigma = Sigma.copy()

# Flip the smallest singular value (last one) and corresponding column/row
# This maintains A = U_svd @ diag(Sigma) @ Vt while changing det(U)
if side == "right":
# Flip last singular value and last row of Vt
Sigma[-1] *= -1
Vt[-1, :] *= -1
else:
# For left polar, flip last singular value and last column of U_svd
Sigma[-1] *= -1
U_svd[:, -1] *= -1

U = U_svd @ Vt

# Use absolute value to ensure P is positive semi-definite
# (even if we flipped Sigma's sign to fix rotation)
Sigma_abs = np.abs(Sigma)

if side == "right":
# P = Vt.T @ diag(|Sigma|) @ Vt
P = Vt.T @ np.diag(Sigma_abs) @ Vt
else:
# P = U_svd @ diag(|Sigma|) @ U_svd.T (left polar: A = P @ U)
P = U_svd @ np.diag(Sigma_abs) @ U_svd.T
return U, P
else:
gs.raise_exception(f"the input must be either torch.Tensor or np.ndarray. got: {type(A)=}")


# ------------------------------------------------------------------------------------
# ------------------------------------- numpy ----------------------------------------
# ------------------------------------------------------------------------------------
Expand Down
Loading