Skip to content

Commit

Permalink
Merge branch 'contact-limiting' into 'main'
Browse files Browse the repository at this point in the history
Make rigid body contact pair-wise limiting optional

See merge request omniverse/warp!564
  • Loading branch information
christophercrouzet committed Jun 12, 2024
2 parents 94df2fe + 95937ad commit 3bb42dd
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 34 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# CHANGELOG

## [Upcoming Release] - 2024-??-??

- Improve memory usage and performance for rigid body contact handling when `self.rigid_mesh_contact_max` is zero (default behavior)

## [1.2.0] - 2024-06-06

- Add a not-a-number floating-point constant that can be used as `wp.NAN` or `wp.nan`.
Expand Down
64 changes: 35 additions & 29 deletions warp/sim/collide.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,10 +859,9 @@ def broadphase_collision_pairs(
contact_shape0[index + num_contacts_a + i] = actual_shape_b
contact_shape1[index + num_contacts_a + i] = actual_shape_a
contact_point_id[index + num_contacts_a + i] = i
contact_point_limit[pair_index_ab] = 2
if mesh_contact_max > 0:
if mesh_contact_max > 0 and contact_point_limit and pair_index_ba < contact_point_limit.shape[0]:
num_contacts_b = wp.min(mesh_contact_max, num_contacts_b)
contact_point_limit[pair_index_ba] = num_contacts_b
contact_point_limit[pair_index_ba] = num_contacts_b
return
else:
num_contacts = 2
Expand All @@ -877,13 +876,11 @@ def broadphase_collision_pairs(
contact_shape0[index + i] = shape_a
contact_shape1[index + i] = shape_b
contact_point_id[index + i] = i
contact_point_limit[pair_index_ab] = 12
# allocate contact points from box B against A
for i in range(12):
contact_shape0[index + 12 + i] = shape_b
contact_shape1[index + 12 + i] = shape_a
contact_point_id[index + 12 + i] = i
contact_point_limit[pair_index_ba] = 12
return
elif actual_type_b == wp.sim.GEO_MESH:
num_contacts_a = 8
Expand All @@ -908,10 +905,9 @@ def broadphase_collision_pairs(
contact_shape1[index + num_contacts_a + i] = actual_shape_a
contact_point_id[index + num_contacts_a + i] = i

contact_point_limit[pair_index_ab] = num_contacts_a
if mesh_contact_max > 0:
if mesh_contact_max > 0 and contact_point_limit and pair_index_ba < contact_point_limit.shape[0]:
num_contacts_b = wp.min(mesh_contact_max, num_contacts_b)
contact_point_limit[pair_index_ba] = num_contacts_b
contact_point_limit[pair_index_ba] = num_contacts_b
return
elif actual_type_b == wp.sim.GEO_PLANE:
if geo.scale[actual_shape_b][0] == 0.0 and geo.scale[actual_shape_b][1] == 0.0:
Expand Down Expand Up @@ -947,11 +943,13 @@ def broadphase_collision_pairs(
contact_shape1[index + num_contacts_a + i] = actual_shape_a
contact_point_id[index + num_contacts_a + i] = i

if mesh_contact_max > 0:
if mesh_contact_max > 0 and contact_point_limit:
num_contacts_a = wp.min(mesh_contact_max, num_contacts_a)
num_contacts_b = wp.min(mesh_contact_max, num_contacts_b)
contact_point_limit[pair_index_ab] = num_contacts_a
contact_point_limit[pair_index_ba] = num_contacts_b
if pair_index_ab < contact_point_limit.shape[0]:
contact_point_limit[pair_index_ab] = num_contacts_a
if pair_index_ba < contact_point_limit.shape[0]:
contact_point_limit[pair_index_ba] = num_contacts_b
return
elif actual_type_a == wp.sim.GEO_PLANE:
return # no plane-plane contacts
Expand All @@ -969,8 +967,11 @@ def broadphase_collision_pairs(
contact_shape0[cp_index] = actual_shape_a
contact_shape1[cp_index] = actual_shape_b
contact_point_id[cp_index] = i
contact_point_limit[pair_index_ab] = num_contacts
contact_point_limit[pair_index_ba] = 0
if contact_point_limit:
if pair_index_ab < contact_point_limit.shape[0]:
contact_point_limit[pair_index_ab] = num_contacts
if pair_index_ba < contact_point_limit.shape[0]:
contact_point_limit[pair_index_ba] = 0


@wp.kernel
Expand Down Expand Up @@ -1005,12 +1006,14 @@ def handle_contact_pairs(
if shape_a == shape_b:
return

if contact_point_limit:
pair_index = shape_a * num_shapes + shape_b
contact_limit = contact_point_limit[pair_index]
if contact_pairwise_counter[pair_index] >= contact_limit:
# reached limit of contact points per contact pair
return

point_id = contact_point_id[tid]
pair_index = shape_a * num_shapes + shape_b
contact_limit = contact_point_limit[pair_index]
if contact_pairwise_counter[pair_index] >= contact_limit:
# reached limit of contact points per contact pair
return

rigid_a = shape_body[shape_a]
X_wb_a = wp.transform_identity()
Expand Down Expand Up @@ -1404,15 +1407,16 @@ def handle_contact_pairs(

d = distance - thickness
if d < rigid_contact_margin:
pair_contact_id = limited_counter_increment(
contact_pairwise_counter, pair_index, contact_tids, tid, contact_limit
)
if pair_contact_id == -1:
# wp.printf("Reached contact point limit %d >= %d for shape pair %d and %d (pair_index: %d)\n",
# contact_pairwise_counter[pair_index], contact_limit, shape_a, shape_b, pair_index)
# reached contact point limit
return
index = limited_counter_increment(contact_count, 0, contact_tids, tid, -1)
if contact_pairwise_counter:
pair_contact_id = limited_counter_increment(
contact_pairwise_counter, pair_index, contact_tids, tid, contact_limit
)
if pair_contact_id == -1:
# wp.printf("Reached contact point limit %d >= %d for shape pair %d and %d (pair_index: %d)\n",
# contact_pairwise_counter[pair_index], contact_limit, shape_a, shape_b, pair_index)
# reached contact point limit
return
index = counter_increment(contact_count, 0, contact_tids, tid)
contact_shape0[index] = shape_a
contact_shape1[index] = shape_b
# transform from world into body frame (so the contact point includes the shape transform)
Expand Down Expand Up @@ -1550,14 +1554,16 @@ def collide(model, state, edge_sdf_iter: int = 10, iterate_mesh_vertices: bool =
model.rigid_contact_normal = wp.clone(model.rigid_contact_normal)
model.rigid_contact_thickness = wp.clone(model.rigid_contact_thickness)
model.rigid_contact_count = wp.zeros_like(model.rigid_contact_count)
model.rigid_contact_pairwise_counter = wp.zeros_like(model.rigid_contact_pairwise_counter)
model.rigid_contact_tids = wp.zeros_like(model.rigid_contact_tids)
model.rigid_contact_shape0 = wp.empty_like(model.rigid_contact_shape0)
model.rigid_contact_shape1 = wp.empty_like(model.rigid_contact_shape1)
if model.rigid_contact_pairwise_counter is not None:
model.rigid_contact_pairwise_counter = wp.zeros_like(model.rigid_contact_pairwise_counter)
else:
model.rigid_contact_count.zero_()
model.rigid_contact_pairwise_counter.zero_()
model.rigid_contact_tids.zero_()
if model.rigid_contact_pairwise_counter is not None:
model.rigid_contact_pairwise_counter.zero_()
model.rigid_contact_shape0.fill_(-1)
model.rigid_contact_shape1.fill_(-1)

Expand Down
22 changes: 17 additions & 5 deletions warp/sim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,11 +1014,23 @@ def allocate_rigid_contacts(self, target=None, count=None, limited_contact_count
target.rigid_contact_broad_shape0 = wp.zeros(self.rigid_contact_max, dtype=wp.int32)
target.rigid_contact_broad_shape1 = wp.zeros(self.rigid_contact_max, dtype=wp.int32)

max_pair_count = self.shape_count * self.shape_count
# maximum number of contact points per contact pair
target.rigid_contact_point_limit = wp.zeros(max_pair_count, dtype=wp.int32)
# currently found contacts per contact pair
target.rigid_contact_pairwise_counter = wp.zeros(max_pair_count, dtype=wp.int32)
if self.rigid_mesh_contact_max > 0:
# add additional buffers to track how many contact points are generated per contact pair
# (significantly increases memory usage, only enable if mesh contacts need to be pruned)
if self.shape_count >= 46340:
# clip the number of potential contacts to avoid signed 32-bit integer overflow
# i.e. when the number of shapes exceeds sqrt(2**31 - 1)
max_pair_count = 2**31 - 1
else:
max_pair_count = self.shape_count * self.shape_count
# maximum number of contact points per contact pair
target.rigid_contact_point_limit = wp.zeros(max_pair_count, dtype=wp.int32)
# currently found contacts per contact pair
target.rigid_contact_pairwise_counter = wp.zeros(max_pair_count, dtype=wp.int32)
else:
target.rigid_contact_point_limit = None
target.rigid_contact_pairwise_counter = None

# ID of thread that found the current contact point
target.rigid_contact_tids = wp.zeros(self.rigid_contact_max, dtype=wp.int32)

Expand Down

0 comments on commit 3bb42dd

Please sign in to comment.