Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions mujoco_warp/_src/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,36 +72,38 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
warp_util.check_toolkit_driver()

# model: check supported features in array types
for field, field_type in (
(mjm.actuator_trntype, types.TrnType),
(mjm.actuator_dyntype, types.DynType),
(mjm.actuator_gaintype, types.GainType),
(mjm.actuator_biastype, types.BiasType),
(mjm.eq_type, types.EqType),
(mjm.geom_type, types.GeomType),
(mjm.sensor_type, types.SensorType),
(mjm.wrap_type, types.WrapType),
for field, field_type, mj_type in (
(mjm.actuator_trntype, types.TrnType, mujoco.mjtTrn),
(mjm.actuator_dyntype, types.DynType, mujoco.mjtDyn),
(mjm.actuator_gaintype, types.GainType, mujoco.mjtGain),
(mjm.actuator_biastype, types.BiasType, mujoco.mjtBias),
(mjm.eq_type, types.EqType, mujoco.mjtEq),
(mjm.geom_type, types.GeomType, mujoco.mjtGeom),
(mjm.sensor_type, types.SensorType, mujoco.mjtSensor),
(mjm.wrap_type, types.WrapType, mujoco.mjtWrap),
):
missing = ~np.isin(field, field_type)
if missing.any():
raise NotImplementedError(f"{field_type.__name__}: {field[missing]} not supported.")
names = [mj_type(v).name for v in field[missing]]
raise NotImplementedError(f"{names} not supported.")

# opt: check supported features in scalar types
for field, field_type in (
(mjm.opt.integrator, types.IntegratorType),
(mjm.opt.cone, types.ConeType),
(mjm.opt.solver, types.SolverType),
for field, field_type, mj_type in (
(mjm.opt.integrator, types.IntegratorType, mujoco.mjtIntegrator),
(mjm.opt.cone, types.ConeType, mujoco.mjtCone),
(mjm.opt.solver, types.SolverType, mujoco.mjtSolver),
):
if field not in set(field_type):
raise NotImplementedError(f"{field_type.__name__} {field} is unsupported.")
raise NotImplementedError(f"{mj_type(field).name} is unsupported.")

# opt: check supported features in scalar flag types
for field, field_type in (
(mjm.opt.disableflags, types.DisableBit),
(mjm.opt.enableflags, types.EnableBit),
for field, field_type, mj_type in (
(mjm.opt.disableflags, types.DisableBit, mujoco.mjtDisableBit),
(mjm.opt.enableflags, types.EnableBit, mujoco.mjtEnableBit),
):
if field & ~np.bitwise_or.reduce(field_type):
raise NotImplementedError(f"{field_type.__name__} {field} is unsupported.")
unsupported = field & ~np.bitwise_or.reduce(field_type)
if unsupported:
raise NotImplementedError(f"{mj_type(unsupported).name} is unsupported.")

if ((mjm.flex_contype != 0) | (mjm.flex_conaffinity != 0)).any():
raise NotImplementedError("Flex collisions are not implemented.")
Expand Down
Loading