Skip to content

Commit ed56c94

Browse files
committed
improve io.put_model NotImplementedError messages
1 parent f0baf70 commit ed56c94

1 file changed

Lines changed: 22 additions & 20 deletions

File tree

mujoco_warp/_src/io.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -72,36 +72,38 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
7272
warp_util.check_toolkit_driver()
7373

7474
# model: check supported features in array types
75-
for field, field_type in (
76-
(mjm.actuator_trntype, types.TrnType),
77-
(mjm.actuator_dyntype, types.DynType),
78-
(mjm.actuator_gaintype, types.GainType),
79-
(mjm.actuator_biastype, types.BiasType),
80-
(mjm.eq_type, types.EqType),
81-
(mjm.geom_type, types.GeomType),
82-
(mjm.sensor_type, types.SensorType),
83-
(mjm.wrap_type, types.WrapType),
75+
for field, field_type, mj_type in (
76+
(mjm.actuator_trntype, types.TrnType, mujoco.mjtTrn),
77+
(mjm.actuator_dyntype, types.DynType, mujoco.mjtDyn),
78+
(mjm.actuator_gaintype, types.GainType, mujoco.mjtGain),
79+
(mjm.actuator_biastype, types.BiasType, mujoco.mjtBias),
80+
(mjm.eq_type, types.EqType, mujoco.mjtEq),
81+
(mjm.geom_type, types.GeomType, mujoco.mjtGeom),
82+
(mjm.sensor_type, types.SensorType, mujoco.mjtSensor),
83+
(mjm.wrap_type, types.WrapType, mujoco.mjtWrap),
8484
):
8585
missing = ~np.isin(field, field_type)
8686
if missing.any():
87-
raise NotImplementedError(f"{field_type.__name__}: {field[missing]} not supported.")
87+
names = [mj_type(v).name for v in field[missing]]
88+
raise NotImplementedError(f"{names} not supported.")
8889

8990
# opt: check supported features in scalar types
90-
for field, field_type in (
91-
(mjm.opt.integrator, types.IntegratorType),
92-
(mjm.opt.cone, types.ConeType),
93-
(mjm.opt.solver, types.SolverType),
91+
for field, field_type, mj_type in (
92+
(mjm.opt.integrator, types.IntegratorType, mujoco.mjtIntegrator),
93+
(mjm.opt.cone, types.ConeType, mujoco.mjtCone),
94+
(mjm.opt.solver, types.SolverType, mujoco.mjtSolver),
9495
):
9596
if field not in set(field_type):
96-
raise NotImplementedError(f"{field_type.__name__} {field} is unsupported.")
97+
raise NotImplementedError(f"{mj_type(field).name} is unsupported.")
9798

9899
# opt: check supported features in scalar flag types
99-
for field, field_type in (
100-
(mjm.opt.disableflags, types.DisableBit),
101-
(mjm.opt.enableflags, types.EnableBit),
100+
for field, field_type, mj_type in (
101+
(mjm.opt.disableflags, types.DisableBit, mujoco.mjtDisableBit),
102+
(mjm.opt.enableflags, types.EnableBit, mujoco.mjtEnableBit),
102103
):
103-
if field & ~np.bitwise_or.reduce(field_type):
104-
raise NotImplementedError(f"{field_type.__name__} {field} is unsupported.")
104+
unsupported = field & ~np.bitwise_or.reduce(field_type)
105+
if unsupported:
106+
raise NotImplementedError(f"{mj_type(unsupported).name} is unsupported.")
105107

106108
if ((mjm.flex_contype != 0) | (mjm.flex_conaffinity != 0)).any():
107109
raise NotImplementedError("Flex collisions are not implemented.")

0 commit comments

Comments
 (0)