diff --git a/mujoco_warp/_src/collision_sdf.py b/mujoco_warp/_src/collision_sdf.py index 0842421a0..1c43390c4 100644 --- a/mujoco_warp/_src/collision_sdf.py +++ b/mujoco_warp/_src/collision_sdf.py @@ -28,6 +28,7 @@ from .types import vec5 from .types import vec8 from .types import vec8i +from .types import vec_pluginattr from .util_misc import halton from .warp_util import event_scope @@ -38,8 +39,8 @@ class OptimizationParams: rel_mat: wp.mat33 rel_pos: wp.vec3 - attr1: wp.vec3 - attr2: wp.vec3 + attr1: vec_pluginattr + attr2: vec_pluginattr @wp.struct @@ -81,14 +82,18 @@ def get_sdf_params( oct_aabb: wp.array2d(dtype=wp.vec3), oct_coeff: wp.array(dtype=vec8), plugin: wp.array(dtype=int), - plugin_attr: wp.array(dtype=wp.vec3f), + plugin_attr: wp.array(dtype=vec_pluginattr), # In: g_type: int, g_size: wp.vec3, plugin_id: int, mesh_id: int, -) -> Tuple[wp.vec3, int, VolumeData, MeshData]: - attributes = g_size +) -> Tuple[vec_pluginattr, int, VolumeData, MeshData]: + # default attributes from geom size, first 3 values copied + attributes = vec_pluginattr() + attributes[0] = g_size[0] + attributes[1] = g_size[1] + attributes[2] = g_size[2] plugin_index = -1 volume_data = VolumeData() @@ -211,13 +216,21 @@ def grad_ellipsoid(p: wp.vec3, size: wp.vec3) -> wp.vec3: @wp.func -def user_sdf(p: wp.vec3, attr: wp.vec3, sdf_type: int) -> float: +def user_sdf(p: wp.vec3, attr: vec_pluginattr, sdf_type: int) -> float: + """User-defined SDF function. + + Access attributes via attr[i] where i is the attribute index (0 to _NPLUGINATTR-1). + """ wp.printf("ERROR: user_sdf function must be implemented by user code\n") return 0.0 @wp.func -def user_sdf_grad(p: wp.vec3, attr: wp.vec3, sdf_type: int) -> wp.vec3: +def user_sdf_grad(p: wp.vec3, attr: vec_pluginattr, sdf_type: int) -> wp.vec3: + """User-defined SDF gradient function. + + Access attributes via attr[i] where i is the attribute index (0 to _NPLUGINATTR-1). + """ wp.printf("ERROR: user_sdf_grad function must be implemented by user code\n") return wp.vec3(0.0) @@ -355,15 +368,17 @@ def sample_volume_grad(xyz: wp.vec3, volume_data: VolumeData) -> wp.vec3: @wp.func -def sdf(type: int, p: wp.vec3, attr: wp.vec3, sdf_type: int, volume_data: VolumeData, mesh_data: MeshData) -> float: +def sdf(type: int, p: wp.vec3, attr: vec_pluginattr, sdf_type: int, volume_data: VolumeData, mesh_data: MeshData) -> float: + # extract first 3 elements as vec3 for primitive sdf functions + attr_vec3 = wp.vec3(attr[0], attr[1], attr[2]) if type == GeomType.PLANE: return p[2] elif type == GeomType.SPHERE: - return sphere(p, attr) + return sphere(p, attr_vec3) elif type == GeomType.BOX: - return box(p, attr) + return box(p, attr_vec3) elif type == GeomType.ELLIPSOID: - return ellipsoid(p, attr) + return ellipsoid(p, attr_vec3) elif type == GeomType.MESH and mesh_data.valid: mesh_data.pnt = p mesh_data.vec = -wp.normalize(p) @@ -404,16 +419,20 @@ def sdf(type: int, p: wp.vec3, attr: wp.vec3, sdf_type: int, volume_data: Volume @wp.func -def sdf_grad(type: int, p: wp.vec3, attr: wp.vec3, sdf_type: int, volume_data: VolumeData, mesh_data: MeshData) -> wp.vec3: +def sdf_grad( + type: int, p: wp.vec3, attr: vec_pluginattr, sdf_type: int, volume_data: VolumeData, mesh_data: MeshData +) -> wp.vec3: + # extract first 3 elements as vec3 for primitive sdf functions + attr_vec3 = wp.vec3(attr[0], attr[1], attr[2]) if type == GeomType.PLANE: grad = wp.vec3(0.0, 0.0, 1.0) return grad elif type == GeomType.SPHERE: return grad_sphere(p) elif type == GeomType.BOX: - return grad_box(p, attr) + return grad_box(p, attr_vec3) elif type == GeomType.ELLIPSOID: - return grad_ellipsoid(p, attr) + return grad_ellipsoid(p, attr_vec3) elif type == GeomType.MESH and mesh_data.valid: mesh_data.pnt = p mesh_data.vec = -wp.normalize(p) @@ -449,8 +468,8 @@ def clearance( type1: int, p1: wp.vec3, p2: wp.vec3, - s1: wp.vec3, - s2: wp.vec3, + s1: vec_pluginattr, + s2: vec_pluginattr, sdf_type1: int, sdf_type2: int, sfd_intersection: bool, @@ -579,8 +598,8 @@ def gradient_descent( # In: type1: int, x0_initial: wp.vec3, - attr1: wp.vec3, - attr2: wp.vec3, + attr1: vec_pluginattr, + attr2: vec_pluginattr, pos1: wp.vec3, rot1: wp.mat33, pos2: wp.vec3, @@ -659,7 +678,7 @@ def _sdf_narrowphase( pair_gap: wp.array2d(dtype=float), pair_friction: wp.array2d(dtype=vec5), plugin: wp.array(dtype=int), - plugin_attr: wp.array(dtype=wp.vec3f), + plugin_attr: wp.array(dtype=vec_pluginattr), geom_plugin_index: wp.array(dtype=int), # Data in: geom_xpos_in: wp.array2d(dtype=wp.vec3), diff --git a/mujoco_warp/_src/io.py b/mujoco_warp/_src/io.py index e2ed2d238..39d27db74 100644 --- a/mujoco_warp/_src/io.py +++ b/mujoco_warp/_src/io.py @@ -363,9 +363,11 @@ def geom_trid_index(i, j): current = [] else: current.append(v) - # Pad with zeros if less than 3 - attr_values += [0.0] * (3 - len(attr_values)) - m.plugin_attr.append(attr_values[:3]) + if len(attr_values) > types._NPLUGINATTR: + raise ValueError(f"Plugin has {len(attr_values)} attributes, which exceeds the maximum of {types._NPLUGINATTR}. ") + # pad with zeros to _NPLUGINATTR + attr_values += [0.0] * (types._NPLUGINATTR - len(attr_values)) + m.plugin_attr.append(attr_values[: types._NPLUGINATTR]) # equality constraint addresses m.eq_connect_adr = np.nonzero(mjm.eq_type == types.EqType.CONNECT)[0] diff --git a/mujoco_warp/_src/sensor.py b/mujoco_warp/_src/sensor.py index 431ad9d51..304b7cc0d 100644 --- a/mujoco_warp/_src/sensor.py +++ b/mujoco_warp/_src/sensor.py @@ -39,6 +39,7 @@ from .types import vec6 from .types import vec8 from .types import vec8i +from .types import vec_pluginattr from .util_misc import inside_geom from .warp_util import cache_kernel from .warp_util import event_scope @@ -2106,7 +2107,7 @@ def _sensor_tactile( sensor_dim: wp.array(dtype=int), sensor_adr: wp.array(dtype=int), plugin: wp.array(dtype=int), - plugin_attr: wp.array(dtype=wp.vec3f), + plugin_attr: wp.array(dtype=vec_pluginattr), geom_plugin_index: wp.array(dtype=int), taxel_vertadr: wp.array(dtype=int), taxel_sensorid: wp.array(dtype=int), diff --git a/mujoco_warp/_src/types.py b/mujoco_warp/_src/types.py index 6ff16d703..d3540066c 100644 --- a/mujoco_warp/_src/types.py +++ b/mujoco_warp/_src/types.py @@ -32,6 +32,9 @@ TILE_SIZE_JTDAJ_SPARSE = 16 TILE_SIZE_JTDAJ_DENSE = 16 +# maximum number of plugin attributes +_NPLUGINATTR = 128 + # TODO(team): add check that all wp.launch_tiled 'block_dim' settings are configurable @dataclasses.dataclass @@ -600,6 +603,10 @@ class vec11f(wp.types.vector(length=11, dtype=float)): pass +class vec_pluginattr(wp.types.vector(length=_NPLUGINATTR, dtype=float)): + pass + + class mat23f(wp.types.matrix(shape=(2, 3), dtype=float)): pass @@ -617,6 +624,7 @@ class mat63f(wp.types.matrix(shape=(6, 3), dtype=float)): vec8 = vec8f vec10 = vec10f vec11 = vec11f +vec128 = vec_pluginattr mat23 = mat23f mat43 = mat43f mat63 = mat63f @@ -999,7 +1007,7 @@ class Model: sensor_adr: address in sensor array (nsensor,) sensor_cutoff: cutoff for real and positive; 0: ignore (nsensor,) plugin: globally registered plugin slot number (nplugin,) - plugin_attr: config attributes of geom plugin (nplugin, 3) + plugin_attr: config attributes of geom plugin (nplugin, _NPLUGINATTR) M_rownnz: number of non-zeros in each row of qM (nv,) M_rowadr: index of each row in qM (nv,) M_colind: column indices of non-zeros in qM (nM,) @@ -1347,7 +1355,7 @@ class Model: sensor_adr: array("nsensor", int) sensor_cutoff: array("nsensor", float) plugin: array("nplugin", int) - plugin_attr: array("nplugin", wp.vec3f) + plugin_attr: array("nplugin", vec_pluginattr) M_rownnz: array("nv", int) M_rowadr: array("nv", int) M_colind: array("nC", int) diff --git a/mujoco_warp/test_data/collision_sdf/bolt.py b/mujoco_warp/test_data/collision_sdf/bolt.py index 07a99a1e0..9b80926f7 100644 --- a/mujoco_warp/test_data/collision_sdf/bolt.py +++ b/mujoco_warp/test_data/collision_sdf/bolt.py @@ -1,5 +1,7 @@ import warp as wp +from mujoco_warp._src.types import vec_pluginattr + @wp.func def Fract(x: float) -> float: @@ -22,7 +24,7 @@ def Intersection(a: float, b: float) -> float: @wp.func -def bolt(p: wp.vec3, attr: wp.vec3) -> float: +def bolt(p: wp.vec3, attr: vec_pluginattr) -> float: screw = 12.0 radius = wp.sqrt(p[0] * p[0] + p[1] * p[1]) - attr[0] sqrt12 = wp.sqrt(2.0) / 2.0 @@ -52,7 +54,7 @@ def bolt(p: wp.vec3, attr: wp.vec3) -> float: @wp.func -def bolt_sdf_grad(p: wp.vec3, attr: wp.vec3) -> wp.vec3: +def bolt_sdf_grad(p: wp.vec3, attr: vec_pluginattr) -> wp.vec3: grad = wp.vec3() eps = 1e-6 f_original = bolt(p, attr) diff --git a/mujoco_warp/test_data/collision_sdf/gear.py b/mujoco_warp/test_data/collision_sdf/gear.py index c505f2644..bc4ea227d 100644 --- a/mujoco_warp/test_data/collision_sdf/gear.py +++ b/mujoco_warp/test_data/collision_sdf/gear.py @@ -1,5 +1,7 @@ import warp as wp +from mujoco_warp._src.types import vec_pluginattr + @wp.func def Subtraction(a: float, b: float) -> float: @@ -44,7 +46,7 @@ def mod(x: float, y: float) -> float: @wp.func -def distance2D(p: wp.vec3, attributes: wp.vec3) -> float: +def distance2D(p: wp.vec3, attributes: vec_pluginattr) -> float: # see https://www.shadertoy.com/view/3lG3WR D = 2.8 N = 25.0 @@ -120,13 +122,13 @@ def distance2D(p: wp.vec3, attributes: wp.vec3) -> float: @wp.func -def gear(p: wp.vec3, attr: wp.vec3) -> float: +def gear(p: wp.vec3, attr: vec_pluginattr) -> float: thickness = 0.2 return extrusion(p, distance2D(p, attr), thickness / 2.0) @wp.func -def gear_sdf_grad(p: wp.vec3, attr: wp.vec3) -> wp.vec3: +def gear_sdf_grad(p: wp.vec3, attr: vec_pluginattr) -> wp.vec3: grad = wp.vec3() eps = 1e-6 f_original = gear(p, attr) diff --git a/mujoco_warp/test_data/collision_sdf/nut.py b/mujoco_warp/test_data/collision_sdf/nut.py index 47b8cc736..452842ae0 100644 --- a/mujoco_warp/test_data/collision_sdf/nut.py +++ b/mujoco_warp/test_data/collision_sdf/nut.py @@ -1,5 +1,7 @@ import warp as wp +from mujoco_warp._src.types import vec_pluginattr + @wp.func def Fract(x: float) -> float: @@ -22,7 +24,7 @@ def Intersection(a: float, b: float) -> float: @wp.func -def nut(p: wp.vec3, attr: wp.vec3) -> float: +def nut(p: wp.vec3, attr: vec_pluginattr) -> float: screw = 12.0 radius2 = wp.sqrt(p[0] * p[0] + p[1] * p[1]) - attr[0] sqrt12 = wp.sqrt(2.0) / 2.0 @@ -47,7 +49,7 @@ def nut(p: wp.vec3, attr: wp.vec3) -> float: @wp.func -def nut_sdf_grad(p: wp.vec3, attr: wp.vec3) -> wp.vec3: +def nut_sdf_grad(p: wp.vec3, attr: vec_pluginattr) -> wp.vec3: grad = wp.vec3() eps = 1e-6 f_original = nut(p, attr) diff --git a/mujoco_warp/test_data/collision_sdf/torus.py b/mujoco_warp/test_data/collision_sdf/torus.py index 2cc05e892..cfd514731 100644 --- a/mujoco_warp/test_data/collision_sdf/torus.py +++ b/mujoco_warp/test_data/collision_sdf/torus.py @@ -2,9 +2,11 @@ import warp as wp +from mujoco_warp._src.types import vec_pluginattr + @wp.func -def torus(p: wp.vec3, attr: wp.vec3) -> wp.float32: +def torus(p: wp.vec3, attr: vec_pluginattr) -> wp.float32: major_radius = attr[0] minor_radius = attr[1] @@ -14,7 +16,7 @@ def torus(p: wp.vec3, attr: wp.vec3) -> wp.float32: @wp.func -def torus_sdf_grad(p: wp.vec3, attr: wp.vec3) -> wp.vec3: +def torus_sdf_grad(p: wp.vec3, attr: vec_pluginattr) -> wp.vec3: grad = wp.vec3() major_radius = attr[0] minor_val = attr[1] diff --git a/mujoco_warp/test_data/collision_sdf/utils.py b/mujoco_warp/test_data/collision_sdf/utils.py index 741e7c9ed..b9759971a 100644 --- a/mujoco_warp/test_data/collision_sdf/utils.py +++ b/mujoco_warp/test_data/collision_sdf/utils.py @@ -20,6 +20,8 @@ import mujoco import warp as wp +from mujoco_warp._src.types import vec_pluginattr + from .bolt import bolt from .bolt import bolt_sdf_grad from .gear import gear @@ -81,7 +83,7 @@ def register_sdf_plugins(mjwarp) -> Dict[str, int]: sdf_types[SDFType.GEAR] = int(m.plugin[i]) @wp.func - def user_sdf(p: wp.vec3, attr: wp.vec3, sdf_type: int) -> float: + def user_sdf(p: wp.vec3, attr: vec_pluginattr, sdf_type: int) -> float: result = 0.0 if sdf_type == wp.static(sdf_types[SDFType.NUT]): result = nut(p, attr) @@ -94,7 +96,7 @@ def user_sdf(p: wp.vec3, attr: wp.vec3, sdf_type: int) -> float: return result @wp.func - def user_sdf_grad(p: wp.vec3, attr: wp.vec3, sdf_type: int) -> wp.vec3: + def user_sdf_grad(p: wp.vec3, attr: vec_pluginattr, sdf_type: int) -> wp.vec3: if sdf_type == wp.static(sdf_types[SDFType.NUT]): return nut_sdf_grad(p, attr) elif sdf_type == wp.static(sdf_types[SDFType.BOLT]):