@@ -72,36 +72,38 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
7272 warp_util .check_toolkit_driver ()
7373
7474 # model: check supported features in array types
75- for field , field_type in (
76- (mjm .actuator_trntype , types .TrnType ),
77- (mjm .actuator_dyntype , types .DynType ),
78- (mjm .actuator_gaintype , types .GainType ),
79- (mjm .actuator_biastype , types .BiasType ),
80- (mjm .eq_type , types .EqType ),
81- (mjm .geom_type , types .GeomType ),
82- (mjm .sensor_type , types .SensorType ),
83- (mjm .wrap_type , types .WrapType ),
75+ for field , field_type , mj_type in (
76+ (mjm .actuator_trntype , types .TrnType , mujoco . mjtTrn ),
77+ (mjm .actuator_dyntype , types .DynType , mujoco . mjtDyn ),
78+ (mjm .actuator_gaintype , types .GainType , mujoco . mjtGain ),
79+ (mjm .actuator_biastype , types .BiasType , mujoco . mjtBias ),
80+ (mjm .eq_type , types .EqType , mujoco . mjtEq ),
81+ (mjm .geom_type , types .GeomType , mujoco . mjtGeom ),
82+ (mjm .sensor_type , types .SensorType , mujoco . mjtSensor ),
83+ (mjm .wrap_type , types .WrapType , mujoco . mjtWrap ),
8484 ):
8585 missing = ~ np .isin (field , field_type )
8686 if missing .any ():
87- raise NotImplementedError (f"{ field_type .__name__ } : { field [missing ]} not supported." )
87+ names = [mj_type (v ).name for v in field [missing ]]
88+ raise NotImplementedError (f"{ names } not supported." )
8889
8990 # opt: check supported features in scalar types
90- for field , field_type in (
91- (mjm .opt .integrator , types .IntegratorType ),
92- (mjm .opt .cone , types .ConeType ),
93- (mjm .opt .solver , types .SolverType ),
91+ for field , field_type , mj_type in (
92+ (mjm .opt .integrator , types .IntegratorType , mujoco . mjtIntegrator ),
93+ (mjm .opt .cone , types .ConeType , mujoco . mjtCone ),
94+ (mjm .opt .solver , types .SolverType , mujoco . mjtSolver ),
9495 ):
9596 if field not in set (field_type ):
96- raise NotImplementedError (f"{ field_type . __name__ } { field } is unsupported." )
97+ raise NotImplementedError (f"{ mj_type ( field ). name } is unsupported." )
9798
9899 # opt: check supported features in scalar flag types
99- for field , field_type in (
100- (mjm .opt .disableflags , types .DisableBit ),
101- (mjm .opt .enableflags , types .EnableBit ),
100+ for field , field_type , mj_type in (
101+ (mjm .opt .disableflags , types .DisableBit , mujoco . mjtDisableBit ),
102+ (mjm .opt .enableflags , types .EnableBit , mujoco . mjtEnableBit ),
102103 ):
103- if field & ~ np .bitwise_or .reduce (field_type ):
104- raise NotImplementedError (f"{ field_type .__name__ } { field } is unsupported." )
104+ unsupported = field & ~ np .bitwise_or .reduce (field_type )
105+ if unsupported :
106+ raise NotImplementedError (f"{ mj_type (unsupported ).name } is unsupported." )
105107
106108 if ((mjm .flex_contype != 0 ) | (mjm .flex_conaffinity != 0 )).any ():
107109 raise NotImplementedError ("Flex collisions are not implemented." )
0 commit comments