Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions examples/coupling/rigid_mpm_attachment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
MPM to Rigid Link Attachment

Demonstrates attaching MPM particles to rigid links using soft constraints.
"""

import argparse
import os

import torch

import genesis as gs


def main():
parser = argparse.ArgumentParser()
parser.add_argument("-v", "--vis", action="store_true", default=False)
parser.add_argument("-c", "--cpu", action="store_true", default=False)
args = parser.parse_args()

gs.init(backend=gs.cpu if args.cpu else gs.gpu)

scene = gs.Scene(
sim_options=gs.options.SimOptions(dt=2e-3, substeps=20),
mpm_options=gs.options.MPMOptions(
lower_bound=(-1.0, -1.0, 0.0),
upper_bound=(1.0, 1.0, 1.5),
grid_density=64,
),
viewer_options=gs.options.ViewerOptions(
camera_pos=(1.5, 0.0, 0.8),
camera_lookat=(0.0, 0.0, 0.4),
),
show_viewer=args.vis,
)

scene.add_entity(gs.morphs.Plane())

rigid_box = scene.add_entity(
gs.morphs.Box(pos=(0.0, 0.0, 0.55), size=(0.12, 0.12, 0.05), fixed=False),
)

mpm_cube = scene.add_entity(
material=gs.materials.MPM.Elastic(E=5e4, nu=0.3, rho=1000),
morph=gs.morphs.Box(pos=(0.0, 0.0, 0.35), size=(0.15, 0.15, 0.15)),
)

scene.build()

# Attach top particles of MPM cube to the rigid box
mask = mpm_cube.get_particles_in_bbox((-0.08, -0.08, 0.41), (0.08, 0.08, 0.44))
mpm_cube.set_particle_constraints(mask, rigid_box.links[0].idx, stiffness=1e5)

n_steps = 500 if "PYTEST_VERSION" not in os.environ else 1
initial_z = 0.55

for i in range(n_steps):
z_offset = 0.15 * (1 - abs((i % 200) - 100) / 100.0)
target_qpos = torch.tensor([0.0, 0.0, initial_z + z_offset, 1.0, 0.0, 0.0, 0.0], device=gs.device)
rigid_box.set_qpos(target_qpos)
scene.step()


if __name__ == "__main__":
main()
95 changes: 95 additions & 0 deletions genesis/engine/entities/mpm_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,3 +581,98 @@ def get_free(self):
free = self._sanitize_particles_tensor(None, gs.tc_bool)
self.solver._kernel_get_particles_free(self._particle_start, self._n_particles, free)
return free

# ------------------------------------------------------------------------------------
# ------------------------------ particle constraints --------------------------------
# ------------------------------------------------------------------------------------

@gs.assert_built
def get_particles_in_bbox(self, bbox_min, bbox_max):
"""
Get boolean mask for particles within a bounding box.

Parameters
----------
bbox_min : array_like, shape (3,)
Minimum corner of the bounding box [x, y, z].
bbox_max : array_like, shape (3,)
Maximum corner of the bounding box [x, y, z].

Returns
-------
mask : torch.Tensor, shape (n_envs, n_particles)
Boolean mask where True indicates particle is within the bounding box.
"""
bbox_min = torch.as_tensor(bbox_min, dtype=gs.tc_float, device=gs.device)
bbox_max = torch.as_tensor(bbox_max, dtype=gs.tc_float, device=gs.device)

# Get particle positions: shape (n_envs, n_particles, 3)
poss = self.get_particles_pos()
if poss.ndim == 2:
poss = poss.unsqueeze(0) # (1, n_particles, 3)

# Vectorized bbox check: (n_envs, n_particles)
mask = ((bbox_min <= poss) & (poss <= bbox_max)).all(dim=-1)
return mask

@gs.assert_built
def set_particle_constraints(self, particles_mask, link_idx, stiffness):
"""
Attach MPM particles to a rigid link using soft constraints.

The particles will be pulled toward their relative position on the link
using spring forces with critical damping.

Parameters
----------
particles_mask : torch.Tensor, shape (n_envs, n_particles)
Boolean mask indicating which particles to constrain.
link_idx : int
Index of the rigid link to attach particles to.
stiffness : float
Spring stiffness for the constraint.
"""
if not isinstance(link_idx, int):
gs.raise_exception("link_idx must be an integer.")

if not self._solver._constraints_initialized:
self._solver.init_constraints()

# Get link position and quaternion for all envs
rigid_solver = self._sim.coupler.rigid_solver
link_pos = rigid_solver.get_links_pos(links_idx=[link_idx]) # (n_envs, 1, 3)
link_quat = rigid_solver.get_links_quat(links_idx=[link_idx]) # (n_envs, 1, 4)
if link_pos.ndim == 2:
link_pos = link_pos.unsqueeze(0)
link_quat = link_quat.unsqueeze(0)
link_pos = link_pos[:, 0, :] # (n_envs, 3)
link_quat = link_quat[:, 0, :] # (n_envs, 4)

self._solver._kernel_set_particle_constraints(
self._sim.cur_substep_local,
particles_mask,
self._particle_start,
stiffness,
link_idx,
link_pos,
link_quat,
)

@gs.assert_built
def remove_particle_constraints(self, particles_mask=None):
"""
Remove constraints from specified particles, or all if None.

Parameters
----------
particles_mask : torch.Tensor, shape (n_envs, n_particles), optional
Boolean mask indicating which particles to unconstrain. If None, removes all constraints for this entity.
"""
if not self._solver._constraints_initialized:
return

# Remove all constraints for this entity if mask not specified
if particles_mask is None:
particles_mask = torch.ones((self._sim._B, self.n_particles), dtype=torch.bool, device=gs.device)

self._solver._kernel_remove_particle_constraints(particles_mask, self._particle_start)
118 changes: 118 additions & 0 deletions genesis/engine/solvers/mpm_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, scene: "Scene", sim: "Simulator", options: "MPMOptions"):
self._upper_bound = np.array(options.upper_bound)
self._lower_bound = np.array(options.lower_bound)
self._enable_CPIC = options.enable_CPIC
self._constraints_initialized = False

self._n_vvert_supports = self.scene.vis_options.n_support_neighbors

Expand Down Expand Up @@ -163,6 +164,32 @@ def init_vvert_fields(self):
def init_ckpt(self):
self._ckpt = dict()

def init_constraints(self):
"""Lazy initialization of particle constraint fields."""
# Memory check: ensure index fits in int32
if self._n_particles * self._B * 3 > np.iinfo(np.int32).max:
gs.raise_exception(
f"Particle constraint shape (n_envs={self._B}, n_particles={self._n_particles}, 3) is too large. "
"Consider reducing n_envs or n_particles."
)

self._constraints_initialized = True

particle_constraint_info = ti.types.struct(
is_constrained=gs.ti_bool, # whether particle is constrained
target_pos=gs.ti_vec3, # target position for the constraint
stiffness=gs.ti_float, # spring stiffness
link_idx=gs.ti_int, # index of the rigid link (-1 if not linked)
link_local_pos=gs.ti_vec3, # offset from link origin in link's local frame
)

self.particle_constraints = particle_constraint_info.field(
shape=(self._n_particles, self._B), needs_grad=False, layout=ti.Layout.AOS
)

self.particle_constraints.is_constrained.fill(False)
self.particle_constraints.link_idx.fill(-1)

def reset_grad(self):
self.particles.grad.fill(0.0)
self.grid.grad.fill(0.0)
Expand Down Expand Up @@ -502,6 +529,11 @@ def substep_post_coupling(self, f):
self.sim.coupler.rigid_solver.links_state,
self.sim.coupler.rigid_solver._rigid_global_info,
)

# Apply particle constraints after g2p
if self._constraints_initialized:
self.apply_particle_constraints(f, self.sim.coupler.rigid_solver.links_state)

# FIXME: Use existing errno mechanism for this.
if not self._is_state_valid(f):
gs.raise_exception(
Expand Down Expand Up @@ -1008,6 +1040,88 @@ def _kernel_get_mass(
for i_b_ in range(envs_idx.shape[0]):
mass[i_b_] = total_mass

# ------------------------------------------------------------------------------------
# -------------------------------- particle constraints ------------------------------
# ------------------------------------------------------------------------------------

@ti.kernel
def _kernel_set_particle_constraints(
self,
f: ti.i32,
particles_mask: ti.types.ndarray(), # shape [n_envs, n_particles] boolean mask
particle_start: ti.i32,
stiffness: ti.f32,
link_idx: ti.i32,
link_pos: ti.types.ndarray(), # shape [n_envs, 3]
link_quat: ti.types.ndarray(), # shape [n_envs, 4]
):
for i_p_local, i_b in ti.ndrange(particles_mask.shape[1], particles_mask.shape[0]):
if particles_mask[i_b, i_p_local]:
i_p = i_p_local + particle_start

# Get current particle position
pos = self.particles[f, i_p, i_b].pos

# Get link transform
l_pos = ti.Vector([link_pos[i_b, 0], link_pos[i_b, 1], link_pos[i_b, 2]], dt=gs.ti_float)
l_quat = ti.Vector(
[link_quat[i_b, 0], link_quat[i_b, 1], link_quat[i_b, 2], link_quat[i_b, 3]], dt=gs.ti_float
)

# Compute offset in link's local frame
local_pos = gu.ti_inv_transform_by_trans_quat(pos, l_pos, l_quat)

# Store constraint info
self.particle_constraints[i_p, i_b].is_constrained = True
self.particle_constraints[i_p, i_b].target_pos = pos # initial target is current position
self.particle_constraints[i_p, i_b].stiffness = stiffness
self.particle_constraints[i_p, i_b].link_idx = link_idx
self.particle_constraints[i_p, i_b].link_local_pos = local_pos

@ti.kernel
def _kernel_remove_particle_constraints(
self,
particles_mask: ti.types.ndarray(), # shape [n_envs, n_particles] boolean mask
particle_start: ti.i32,
):
for i_p_local, i_b in ti.ndrange(particles_mask.shape[1], particles_mask.shape[0]):
if particles_mask[i_b, i_p_local]:
i_p = i_p_local + particle_start
self.particle_constraints[i_p, i_b].is_constrained = False
self.particle_constraints[i_p, i_b].link_idx = -1

@ti.kernel
def apply_particle_constraints(
self,
f: ti.i32,
links_state: array_class.LinksState,
):
for i_p, i_b in ti.ndrange(self._n_particles, self._B):
if self.particle_constraints[i_p, i_b].is_constrained:
# Update target position from link pose
i_l = self.particle_constraints[i_p, i_b].link_idx
if i_l >= 0:
link_pos = links_state.pos[i_l, i_b]
link_quat = links_state.quat[i_l, i_b]
local_pos = self.particle_constraints[i_p, i_b].link_local_pos
target = gu.ti_transform_by_trans_quat(local_pos, link_pos, link_quat)
self.particle_constraints[i_p, i_b].target_pos = target

# Apply spring force to velocity
target_pos = self.particle_constraints[i_p, i_b].target_pos
stiffness = self.particle_constraints[i_p, i_b].stiffness
mass = self.particles_info[i_p].mass / self._particle_volume_scale

pos = self.particles[f + 1, i_p, i_b].pos
vel = self.particles[f + 1, i_p, i_b].vel

pos_error = pos - target_pos
spring_force = -stiffness * pos_error
damping_force = -2.0 * ti.math.sqrt(stiffness * mass) * vel

dv = self.substep_dt * (spring_force + damping_force) / mass
self.particles[f + 1, i_p, i_b].vel = vel + dv

# ------------------------------------------------------------------------------------
# ----------------------------------- properties -------------------------------------
# ------------------------------------------------------------------------------------
Expand Down Expand Up @@ -1102,6 +1216,10 @@ def grid_offset(self):
def enable_CPIC(self):
return self._enable_CPIC

@property
def enable_particle_constraints(self):
return self._enable_particle_constraints


@ti.func
def signmax(a, eps):
Expand Down
58 changes: 58 additions & 0 deletions tests/test_deformable_physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,64 @@ def test_deformable_parallel(show_viewer):
assert_allclose(water.get_particles_vel(), 0.0, atol=5e-2)


@pytest.mark.required
def test_mpm_particle_constraints(show_viewer):
Comment on lines +224 to +225
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should flag this test as slow if it is (>200s). Beware that it means that it will only run on production CI, so only do this if necessary.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not slow on my machine, but I'm happy to add the tag to make our generic CI not too slow

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not slow on my machine, but I'm happy to add the tag to make our generic CI not too slow

The speed must be benched on the generic CI runners. Not your machine. So you need to run the entire CI pipeline without "slow" mark at least once, otherwise you will never know since it will be skipped.

"""Test MPM particle constraints: bbox selection, attachment, and following."""
scene = gs.Scene(
sim_options=gs.options.SimOptions(
dt=2e-3,
substeps=20,
),
mpm_options=gs.options.MPMOptions(
lower_bound=(-1.0, -1.0, 0.0),
upper_bound=(1.0, 1.0, 1.0),
grid_density=64,
),
show_viewer=show_viewer,
show_FPS=False,
)
scene.add_entity(gs.morphs.Plane())
rigid_box = scene.add_entity(
gs.morphs.Box(
pos=(0, 0, 0.55),
size=(0.12, 0.12, 0.05),
fixed=True,
),
)
mpm_cube = scene.add_entity(
material=gs.materials.MPM.Elastic(
E=5e4,
nu=0.3,
rho=1000,
),
morph=gs.morphs.Box(
pos=(0, 0, 0.35),
size=(0.15, 0.15, 0.15),
),
)
scene.build(n_envs=2)

# Test get_particles_in_bbox - returns (n_envs, n_particles) mask
mask = mpm_cube.get_particles_in_bbox((-0.08, -0.08, 0.41), (0.08, 0.08, 0.44))
assert mask.shape == (2, mpm_cube.n_particles), "mask should be (n_envs, n_particles)"
assert mask.any(), "bbox should select some particles"
assert not mask.all(), "bbox should not select all particles"

# Attach and test following
link_idx = rigid_box.links[0].idx
mpm_cube.set_particle_constraints(mask, link_idx, stiffness=1e5)
initial_rigid_pos = rigid_box.get_pos().clone()
initial_mpm_x = mpm_cube.get_particles_pos()[:, mask[0], 0].mean()

pos_diff = torch.tensor([0.2, 0, 0], device=gs.device)
rigid_box.set_pos(initial_rigid_pos + pos_diff, zero_velocity=False)
for _ in range(30):
scene.step()

mpm_diff = mpm_cube.get_particles_pos()[:, mask[0], 0].mean() - initial_mpm_x
assert mpm_diff > pos_diff[0] * 0.3, f"MPM should follow rigid link. Got {mpm_diff:.3f}"


def test_sf_solver(show_viewer):
import gstaichi as ti

Expand Down
Loading