Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 38 additions & 19 deletions mujoco_warp/_src/collision_sdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .types import vec5
from .types import vec8
from .types import vec8i
from .types import vec_pluginattr
from .util_misc import halton
from .warp_util import event_scope

Expand All @@ -38,8 +39,8 @@
class OptimizationParams:
rel_mat: wp.mat33
rel_pos: wp.vec3
attr1: wp.vec3
attr2: wp.vec3
attr1: vec_pluginattr
attr2: vec_pluginattr


@wp.struct
Expand Down Expand Up @@ -81,14 +82,18 @@ def get_sdf_params(
oct_aabb: wp.array2d(dtype=wp.vec3),
oct_coeff: wp.array(dtype=vec8),
plugin: wp.array(dtype=int),
plugin_attr: wp.array(dtype=wp.vec3f),
plugin_attr: wp.array(dtype=vec_pluginattr),
# In:
g_type: int,
g_size: wp.vec3,
plugin_id: int,
mesh_id: int,
) -> Tuple[wp.vec3, int, VolumeData, MeshData]:
attributes = g_size
) -> Tuple[vec_pluginattr, int, VolumeData, MeshData]:
# default attributes from geom size, first 3 values copied
attributes = vec_pluginattr()
attributes[0] = g_size[0]
attributes[1] = g_size[1]
attributes[2] = g_size[2]
plugin_index = -1
volume_data = VolumeData()

Expand Down Expand Up @@ -211,13 +216,21 @@ def grad_ellipsoid(p: wp.vec3, size: wp.vec3) -> wp.vec3:


@wp.func
def user_sdf(p: wp.vec3, attr: wp.vec3, sdf_type: int) -> float:
def user_sdf(p: wp.vec3, attr: vec_pluginattr, sdf_type: int) -> float:
"""User-defined SDF function.

Access attributes via attr[i] where i is the attribute index (0 to _NPLUGINATTR-1).
"""
wp.printf("ERROR: user_sdf function must be implemented by user code\n")
return 0.0


@wp.func
def user_sdf_grad(p: wp.vec3, attr: wp.vec3, sdf_type: int) -> wp.vec3:
def user_sdf_grad(p: wp.vec3, attr: vec_pluginattr, sdf_type: int) -> wp.vec3:
"""User-defined SDF gradient function.

Access attributes via attr[i] where i is the attribute index (0 to _NPLUGINATTR-1).
"""
wp.printf("ERROR: user_sdf_grad function must be implemented by user code\n")
return wp.vec3(0.0)

Expand Down Expand Up @@ -355,15 +368,17 @@ def sample_volume_grad(xyz: wp.vec3, volume_data: VolumeData) -> wp.vec3:


@wp.func
def sdf(type: int, p: wp.vec3, attr: wp.vec3, sdf_type: int, volume_data: VolumeData, mesh_data: MeshData) -> float:
def sdf(type: int, p: wp.vec3, attr: vec_pluginattr, sdf_type: int, volume_data: VolumeData, mesh_data: MeshData) -> float:
# extract first 3 elements as vec3 for primitive sdf functions
attr_vec3 = wp.vec3(attr[0], attr[1], attr[2])
if type == GeomType.PLANE:
return p[2]
elif type == GeomType.SPHERE:
return sphere(p, attr)
return sphere(p, attr_vec3)
elif type == GeomType.BOX:
return box(p, attr)
return box(p, attr_vec3)
elif type == GeomType.ELLIPSOID:
return ellipsoid(p, attr)
return ellipsoid(p, attr_vec3)
elif type == GeomType.MESH and mesh_data.valid:
mesh_data.pnt = p
mesh_data.vec = -wp.normalize(p)
Expand Down Expand Up @@ -404,16 +419,20 @@ def sdf(type: int, p: wp.vec3, attr: wp.vec3, sdf_type: int, volume_data: Volume


@wp.func
def sdf_grad(type: int, p: wp.vec3, attr: wp.vec3, sdf_type: int, volume_data: VolumeData, mesh_data: MeshData) -> wp.vec3:
def sdf_grad(
type: int, p: wp.vec3, attr: vec_pluginattr, sdf_type: int, volume_data: VolumeData, mesh_data: MeshData
) -> wp.vec3:
# extract first 3 elements as vec3 for primitive sdf functions
attr_vec3 = wp.vec3(attr[0], attr[1], attr[2])
if type == GeomType.PLANE:
grad = wp.vec3(0.0, 0.0, 1.0)
return grad
elif type == GeomType.SPHERE:
return grad_sphere(p)
elif type == GeomType.BOX:
return grad_box(p, attr)
return grad_box(p, attr_vec3)
elif type == GeomType.ELLIPSOID:
return grad_ellipsoid(p, attr)
return grad_ellipsoid(p, attr_vec3)
elif type == GeomType.MESH and mesh_data.valid:
mesh_data.pnt = p
mesh_data.vec = -wp.normalize(p)
Expand Down Expand Up @@ -449,8 +468,8 @@ def clearance(
type1: int,
p1: wp.vec3,
p2: wp.vec3,
s1: wp.vec3,
s2: wp.vec3,
s1: vec_pluginattr,
s2: vec_pluginattr,
sdf_type1: int,
sdf_type2: int,
sfd_intersection: bool,
Expand Down Expand Up @@ -579,8 +598,8 @@ def gradient_descent(
# In:
type1: int,
x0_initial: wp.vec3,
attr1: wp.vec3,
attr2: wp.vec3,
attr1: vec_pluginattr,
attr2: vec_pluginattr,
pos1: wp.vec3,
rot1: wp.mat33,
pos2: wp.vec3,
Expand Down Expand Up @@ -659,7 +678,7 @@ def _sdf_narrowphase(
pair_gap: wp.array2d(dtype=float),
pair_friction: wp.array2d(dtype=vec5),
plugin: wp.array(dtype=int),
plugin_attr: wp.array(dtype=wp.vec3f),
plugin_attr: wp.array(dtype=vec_pluginattr),
geom_plugin_index: wp.array(dtype=int),
# Data in:
geom_xpos_in: wp.array2d(dtype=wp.vec3),
Expand Down
8 changes: 5 additions & 3 deletions mujoco_warp/_src/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,11 @@ def geom_trid_index(i, j):
current = []
else:
current.append(v)
# Pad with zeros if less than 3
attr_values += [0.0] * (3 - len(attr_values))
m.plugin_attr.append(attr_values[:3])
if len(attr_values) > types._NPLUGINATTR:
raise ValueError(f"Plugin has {len(attr_values)} attributes, which exceeds the maximum of {types._NPLUGINATTR}. ")
# pad with zeros to _NPLUGINATTR
attr_values += [0.0] * (types._NPLUGINATTR - len(attr_values))
m.plugin_attr.append(attr_values[: types._NPLUGINATTR])

# equality constraint addresses
m.eq_connect_adr = np.nonzero(mjm.eq_type == types.EqType.CONNECT)[0]
Expand Down
3 changes: 2 additions & 1 deletion mujoco_warp/_src/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .types import vec6
from .types import vec8
from .types import vec8i
from .types import vec_pluginattr
from .util_misc import inside_geom
from .warp_util import cache_kernel
from .warp_util import event_scope
Expand Down Expand Up @@ -2106,7 +2107,7 @@ def _sensor_tactile(
sensor_dim: wp.array(dtype=int),
sensor_adr: wp.array(dtype=int),
plugin: wp.array(dtype=int),
plugin_attr: wp.array(dtype=wp.vec3f),
plugin_attr: wp.array(dtype=vec_pluginattr),
geom_plugin_index: wp.array(dtype=int),
taxel_vertadr: wp.array(dtype=int),
taxel_sensorid: wp.array(dtype=int),
Expand Down
12 changes: 10 additions & 2 deletions mujoco_warp/_src/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
TILE_SIZE_JTDAJ_SPARSE = 16
TILE_SIZE_JTDAJ_DENSE = 16

# maximum number of plugin attributes
_NPLUGINATTR = 128


# TODO(team): add check that all wp.launch_tiled 'block_dim' settings are configurable
@dataclasses.dataclass
Expand Down Expand Up @@ -600,6 +603,10 @@ class vec11f(wp.types.vector(length=11, dtype=float)):
pass


class vec_pluginattr(wp.types.vector(length=_NPLUGINATTR, dtype=float)):
pass


class mat23f(wp.types.matrix(shape=(2, 3), dtype=float)):
pass

Expand All @@ -617,6 +624,7 @@ class mat63f(wp.types.matrix(shape=(6, 3), dtype=float)):
vec8 = vec8f
vec10 = vec10f
vec11 = vec11f
vec128 = vec_pluginattr
mat23 = mat23f
mat43 = mat43f
mat63 = mat63f
Expand Down Expand Up @@ -999,7 +1007,7 @@ class Model:
sensor_adr: address in sensor array (nsensor,)
sensor_cutoff: cutoff for real and positive; 0: ignore (nsensor,)
plugin: globally registered plugin slot number (nplugin,)
plugin_attr: config attributes of geom plugin (nplugin, 3)
plugin_attr: config attributes of geom plugin (nplugin, _NPLUGINATTR)
M_rownnz: number of non-zeros in each row of qM (nv,)
M_rowadr: index of each row in qM (nv,)
M_colind: column indices of non-zeros in qM (nM,)
Expand Down Expand Up @@ -1347,7 +1355,7 @@ class Model:
sensor_adr: array("nsensor", int)
sensor_cutoff: array("nsensor", float)
plugin: array("nplugin", int)
plugin_attr: array("nplugin", wp.vec3f)
plugin_attr: array("nplugin", vec_pluginattr)
M_rownnz: array("nv", int)
M_rowadr: array("nv", int)
M_colind: array("nC", int)
Expand Down
6 changes: 4 additions & 2 deletions mujoco_warp/test_data/collision_sdf/bolt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import warp as wp

from mujoco_warp._src.types import vec_pluginattr


@wp.func
def Fract(x: float) -> float:
Expand All @@ -22,7 +24,7 @@ def Intersection(a: float, b: float) -> float:


@wp.func
def bolt(p: wp.vec3, attr: wp.vec3) -> float:
def bolt(p: wp.vec3, attr: vec_pluginattr) -> float:
screw = 12.0
radius = wp.sqrt(p[0] * p[0] + p[1] * p[1]) - attr[0]
sqrt12 = wp.sqrt(2.0) / 2.0
Expand Down Expand Up @@ -52,7 +54,7 @@ def bolt(p: wp.vec3, attr: wp.vec3) -> float:


@wp.func
def bolt_sdf_grad(p: wp.vec3, attr: wp.vec3) -> wp.vec3:
def bolt_sdf_grad(p: wp.vec3, attr: vec_pluginattr) -> wp.vec3:
grad = wp.vec3()
eps = 1e-6
f_original = bolt(p, attr)
Expand Down
8 changes: 5 additions & 3 deletions mujoco_warp/test_data/collision_sdf/gear.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import warp as wp

from mujoco_warp._src.types import vec_pluginattr


@wp.func
def Subtraction(a: float, b: float) -> float:
Expand Down Expand Up @@ -44,7 +46,7 @@ def mod(x: float, y: float) -> float:


@wp.func
def distance2D(p: wp.vec3, attributes: wp.vec3) -> float:
def distance2D(p: wp.vec3, attributes: vec_pluginattr) -> float:
# see https://www.shadertoy.com/view/3lG3WR
D = 2.8
N = 25.0
Expand Down Expand Up @@ -120,13 +122,13 @@ def distance2D(p: wp.vec3, attributes: wp.vec3) -> float:


@wp.func
def gear(p: wp.vec3, attr: wp.vec3) -> float:
def gear(p: wp.vec3, attr: vec_pluginattr) -> float:
thickness = 0.2
return extrusion(p, distance2D(p, attr), thickness / 2.0)


@wp.func
def gear_sdf_grad(p: wp.vec3, attr: wp.vec3) -> wp.vec3:
def gear_sdf_grad(p: wp.vec3, attr: vec_pluginattr) -> wp.vec3:
grad = wp.vec3()
eps = 1e-6
f_original = gear(p, attr)
Expand Down
6 changes: 4 additions & 2 deletions mujoco_warp/test_data/collision_sdf/nut.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import warp as wp

from mujoco_warp._src.types import vec_pluginattr


@wp.func
def Fract(x: float) -> float:
Expand All @@ -22,7 +24,7 @@ def Intersection(a: float, b: float) -> float:


@wp.func
def nut(p: wp.vec3, attr: wp.vec3) -> float:
def nut(p: wp.vec3, attr: vec_pluginattr) -> float:
screw = 12.0
radius2 = wp.sqrt(p[0] * p[0] + p[1] * p[1]) - attr[0]
sqrt12 = wp.sqrt(2.0) / 2.0
Expand All @@ -47,7 +49,7 @@ def nut(p: wp.vec3, attr: wp.vec3) -> float:


@wp.func
def nut_sdf_grad(p: wp.vec3, attr: wp.vec3) -> wp.vec3:
def nut_sdf_grad(p: wp.vec3, attr: vec_pluginattr) -> wp.vec3:
grad = wp.vec3()
eps = 1e-6
f_original = nut(p, attr)
Expand Down
6 changes: 4 additions & 2 deletions mujoco_warp/test_data/collision_sdf/torus.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import warp as wp

from mujoco_warp._src.types import vec_pluginattr


@wp.func
def torus(p: wp.vec3, attr: wp.vec3) -> wp.float32:
def torus(p: wp.vec3, attr: vec_pluginattr) -> wp.float32:
major_radius = attr[0]
minor_radius = attr[1]

Expand All @@ -14,7 +16,7 @@ def torus(p: wp.vec3, attr: wp.vec3) -> wp.float32:


@wp.func
def torus_sdf_grad(p: wp.vec3, attr: wp.vec3) -> wp.vec3:
def torus_sdf_grad(p: wp.vec3, attr: vec_pluginattr) -> wp.vec3:
grad = wp.vec3()
major_radius = attr[0]
minor_val = attr[1]
Expand Down
6 changes: 4 additions & 2 deletions mujoco_warp/test_data/collision_sdf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import mujoco
import warp as wp

from mujoco_warp._src.types import vec_pluginattr

from .bolt import bolt
from .bolt import bolt_sdf_grad
from .gear import gear
Expand Down Expand Up @@ -81,7 +83,7 @@ def register_sdf_plugins(mjwarp) -> Dict[str, int]:
sdf_types[SDFType.GEAR] = int(m.plugin[i])

@wp.func
def user_sdf(p: wp.vec3, attr: wp.vec3, sdf_type: int) -> float:
def user_sdf(p: wp.vec3, attr: vec_pluginattr, sdf_type: int) -> float:
result = 0.0
if sdf_type == wp.static(sdf_types[SDFType.NUT]):
result = nut(p, attr)
Expand All @@ -94,7 +96,7 @@ def user_sdf(p: wp.vec3, attr: wp.vec3, sdf_type: int) -> float:
return result

@wp.func
def user_sdf_grad(p: wp.vec3, attr: wp.vec3, sdf_type: int) -> wp.vec3:
def user_sdf_grad(p: wp.vec3, attr: vec_pluginattr, sdf_type: int) -> wp.vec3:
if sdf_type == wp.static(sdf_types[SDFType.NUT]):
return nut_sdf_grad(p, attr)
elif sdf_type == wp.static(sdf_types[SDFType.BOLT]):
Expand Down
Loading