Skip to content

Commit ced9c9b

Browse files
committed
Refactor BVH
1 parent cbcb417 commit ced9c9b

File tree

6 files changed

+297
-350
lines changed

6 files changed

+297
-350
lines changed

rustfmt.toml

+2
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@ imports_layout = "Mixed"
33
imports_granularity = "Crate"
44
group_imports = "StdExternalCrate"
55
max_width = 120
6+
fn_call_width = 120
7+
chain_width = 120

src/array2.rs

+2-10
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,14 @@ impl<T> AsRef<[T]> for Array2<T> {
5252
impl<T> std::ops::Index<(usize, usize)> for Array2<T> {
5353
type Output = T;
5454
fn index(&self, index: (usize, usize)) -> &Self::Output {
55-
assert!(
56-
index.0 < self.size.0 && index.1 < self.size.1,
57-
"index {index:?} is out of bounds {:?}",
58-
self.size
59-
);
55+
assert!(index.0 < self.size.0 && index.1 < self.size.1, "index {index:?} is out of bounds {:?}", self.size);
6056
&self.data[index.1 * self.size.0 + index.0]
6157
}
6258
}
6359

6460
impl<T> std::ops::IndexMut<(usize, usize)> for Array2<T> {
6561
fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
66-
assert!(
67-
index.0 < self.size.0 && index.1 < self.size.1,
68-
"index {index:?} is out of bounds {:?}",
69-
self.size
70-
);
62+
assert!(index.0 < self.size.0 && index.1 < self.size.1, "index {index:?} is out of bounds {:?}", self.size);
7163
&mut self.data[index.1 * self.size.0 + index.0]
7264
}
7365
}

src/bvh.cl

+32-32
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,11 @@ typedef struct {
33
double2 bottomright;
44
} AABB;
55

6-
enum {
6+
typedef enum {
77
TAG_LEAF,
88
TAG_TREE
9-
};
9+
} NodeTag;
1010

11-
typedef struct {
12-
uint object2_index;
13-
} Leaf;
14-
15-
typedef struct {
16-
uint left;
17-
uint right;
18-
} Tree;
19-
20-
typedef struct {
21-
uint tag;
22-
union {
23-
Leaf leaf;
24-
Tree tree;
25-
};
26-
} NodeKind;
27-
28-
typedef struct {
29-
AABB aabb;
30-
NodeKind kind;
31-
} Node;
3211

3312
bool intersects(const AABB *a, const AABB *b) {
3413
return a->topleft.x <= b->bottomright.x
@@ -42,7 +21,11 @@ bool intersects(const AABB *a, const AABB *b) {
4221
uint find_intersections_with(
4322
uint root,
4423
uint object1_index,
45-
global const Node *bvh_nodes,
24+
global const AABB *bvh_node_aabbs,
25+
global const NodeTag *bvh_node_tags,
26+
global const uint *bvh_node_leaf_indices,
27+
global const uint *bvh_node_tree_left,
28+
global const uint *bvh_node_tree_right,
4629
global const AABB *object_aabbs,
4730
global const double2 *positions,
4831
global const double *radii,
@@ -60,10 +43,10 @@ uint find_intersections_with(
6043

6144
while (sp > 0) {
6245
uint node_id = stack[--sp];
63-
const Node node = bvh_nodes[node_id];
64-
if (intersects(&object_aabb, &node.aabb)) {
65-
if (node.kind.tag == TAG_LEAF) {
66-
uint object2_index = node.kind.leaf.object2_index;
46+
const AABB aabb = bvh_node_aabbs[node_id];
47+
if (intersects(&object_aabb, &aabb)) {
48+
if (bvh_node_tags[node_id] == TAG_LEAF) {
49+
uint object2_index = bvh_node_leaf_indices[node_id];
6750
if (object2_index != object1_index) {
6851
const double2 object2_position = positions[object2_index];
6952
const double object2_radius = radii[object2_index];
@@ -77,8 +60,8 @@ uint find_intersections_with(
7760
}
7861
}
7962
} else {
80-
stack[sp++] = node.kind.tree.left;
81-
stack[sp++] = node.kind.tree.right;
63+
stack[sp++] = bvh_node_tree_left[node_id];
64+
stack[sp++] = bvh_node_tree_right[node_id];
8265
}
8366
}
8467
}
@@ -87,7 +70,11 @@ uint find_intersections_with(
8770

8871
kernel void bvh_find_candidates(
8972
const uint root,
90-
global const Node *bvh_nodes,
73+
global const AABB *bvh_node_aabbs,
74+
global const NodeTag *bvh_node_tags,
75+
global const uint *bvh_node_leaf_indices,
76+
global const uint *bvh_node_tree_left,
77+
global const uint *bvh_node_tree_right,
9178
global const AABB *object_aabbs,
9279
global const double2 *positions,
9380
global const double *radii,
@@ -99,7 +86,20 @@ kernel void bvh_find_candidates(
9986
for (uint i = 0; i < MAX_CANDIDATES; ++i) {
10087
candidates[i] = (uint2)(0, 0);
10188
}
102-
uint candidates_end = find_intersections_with(root, object1_index, bvh_nodes, object_aabbs, positions, radii, candidates, 0);
89+
uint candidates_end = find_intersections_with(
90+
root,
91+
object1_index,
92+
bvh_node_aabbs,
93+
bvh_node_tags,
94+
bvh_node_leaf_indices,
95+
bvh_node_tree_left,
96+
bvh_node_tree_right,
97+
object_aabbs,
98+
positions,
99+
radii,
100+
candidates,
101+
0
102+
);
103103
for (uint i = 0; i < candidates_end; ++i) {
104104
global_candidates[object1_index * max_candidates + i] = candidates[i];
105105
}

src/bvh.rs

+93-42
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::mem::swap;
1+
use std::{mem::swap, u32};
22

33
use itertools::Itertools;
44
use rayon::slice::ParallelSliceMut;
@@ -7,10 +7,12 @@ use crate::{physics::NormalizedCollisionPair, vector2::Vector2};
77

88
#[derive(Default, Clone)]
99
pub struct Bvh {
10-
nodes: Vec<Node>,
10+
node_aabbs: Vec<AABB>,
11+
node_tags: Vec<NodeTag>,
12+
node_leaf_indices: Vec<u32>,
13+
node_tree_left: Vec<NodeId>,
14+
node_tree_right: Vec<NodeId>,
1115
object_aabbs: Vec<AABB>,
12-
object_positions: Vec<Vector2<f64>>,
13-
object_radii: Vec<f64>,
1416
}
1517

1618
impl Bvh {
@@ -32,16 +34,21 @@ impl Bvh {
3234
});
3335
}
3436
items.par_sort_unstable_by_key(|item| item.morton_code);
35-
let mut nodes = Vec::with_capacity(positions.len() * 2);
37+
38+
let mut node_aabbs = Vec::default();
39+
let mut node_tags = Vec::default();
40+
let mut node_leaf_indices = Vec::default();
41+
let mut node_tree_left = Vec::default();
42+
let mut node_tree_right = Vec::default();
3643
for item in &items {
37-
nodes.push(Node {
38-
aabb: object_aabbs[item.object_index],
39-
kind: NodeKind::Leaf(u32::try_from(item.object_index).unwrap()),
40-
});
44+
node_aabbs.push(object_aabbs[item.object_index]);
45+
node_tags.push(NodeTag::Leaf);
46+
node_leaf_indices.push(u32::try_from(item.object_index).unwrap());
47+
node_tree_left.push(NodeId::INVALID);
48+
node_tree_right.push(NodeId::INVALID);
4149
}
42-
let mut combine_area = (0..nodes.len())
43-
.map(|i| NodeId(u32::try_from(i).unwrap()))
44-
.collect_vec();
50+
51+
let mut combine_area = (0..node_aabbs.len()).map(|i| NodeId(u32::try_from(i).unwrap())).collect_vec();
4552
let mut combine_area_tmp = Vec::with_capacity(combine_area.len().div_ceil(2));
4653
let (mut combine_area, mut combine_area_tmp) = (&mut combine_area, &mut combine_area_tmp);
4754
while combine_area.len() > 1 {
@@ -52,8 +59,8 @@ impl Bvh {
5259
} else {
5360
let left = chunk[0];
5461
let right = chunk[1];
55-
let left_aabb = nodes[usize::try_from(left.0).unwrap()].aabb;
56-
let right_aabb = nodes[usize::try_from(right.0).unwrap()].aabb;
62+
let left_aabb = node_aabbs[usize::try_from(left.0).unwrap()];
63+
let right_aabb = node_aabbs[usize::try_from(right.0).unwrap()];
5764
let aabb = AABB {
5865
topleft: Vector2::new(
5966
left_aabb.topleft.x.min(right_aabb.topleft.x),
@@ -64,45 +71,76 @@ impl Bvh {
6471
left_aabb.bottomright.y.max(right_aabb.bottomright.y),
6572
),
6673
};
67-
let node_id = NodeId(u32::try_from(nodes.len()).unwrap());
68-
nodes.push(Node {
69-
aabb,
70-
kind: NodeKind::Tree { left, right },
71-
});
74+
let node_id = NodeId(u32::try_from(node_aabbs.len()).unwrap());
75+
node_aabbs.push(aabb);
76+
node_tags.push(NodeTag::Tree);
77+
node_leaf_indices.push(u32::MAX);
78+
node_tree_left.push(left);
79+
node_tree_right.push(right);
7280
node_id
7381
};
7482
combine_area_tmp.push(node_id);
7583
}
7684
swap(&mut combine_area, &mut combine_area_tmp);
7785
}
7886
Bvh {
87+
node_aabbs,
88+
node_tags,
89+
node_leaf_indices,
90+
node_tree_left,
91+
node_tree_right,
7992
object_aabbs,
80-
object_positions: positions.to_vec(),
81-
object_radii: radii.to_vec(),
82-
nodes,
8393
}
8494
}
8595

86-
pub fn find_intersections(&self, object_index: usize, candidates: &mut Vec<NormalizedCollisionPair>) {
87-
if !self.nodes.is_empty() {
88-
self.find_intersections_with(object_index, candidates);
89-
}
96+
pub fn node_aabbs(&self) -> &[AABB] {
97+
&self.node_aabbs
98+
}
99+
100+
pub fn node_tags(&self) -> &[NodeTag] {
101+
&self.node_tags
102+
}
103+
104+
pub fn node_leaf_indices(&self) -> &[u32] {
105+
&self.node_leaf_indices
106+
}
107+
108+
pub fn node_tree_left(&self) -> &[NodeId] {
109+
&self.node_tree_left
90110
}
91111

92-
pub fn nodes(&self) -> &[Node] {
93-
&self.nodes
112+
pub fn node_tree_right(&self) -> &[NodeId] {
113+
&self.node_tree_right
94114
}
95115

96116
pub fn root(&self) -> NodeId {
97-
let id = self.nodes.len().checked_sub(1).unwrap();
117+
let id = self.node_aabbs.len().checked_sub(1).unwrap();
98118
NodeId(u32::try_from(id).unwrap())
99119
}
100120

101121
pub fn object_aabbs(&self) -> &[AABB] {
102122
&self.object_aabbs
103123
}
104124

105-
fn find_intersections_with(&self, object1_index: usize, candidates: &mut Vec<NormalizedCollisionPair>) {
125+
pub fn find_intersections(
126+
&self,
127+
object_index: usize,
128+
positions: &[Vector2<f64>],
129+
radii: &[f64],
130+
candidates: &mut Vec<NormalizedCollisionPair>,
131+
) {
132+
if !self.node_aabbs.is_empty() {
133+
self.find_intersections_with(object_index, positions, radii, candidates);
134+
}
135+
}
136+
137+
fn find_intersections_with(
138+
&self,
139+
object1_index: usize,
140+
positions: &[Vector2<f64>],
141+
radii: &[f64],
142+
candidates: &mut Vec<NormalizedCollisionPair>,
143+
) {
106144
const STACK_SIZE: usize = 16;
107145
let mut stack = [NodeId(0); STACK_SIZE];
108146
let mut sp = 0;
@@ -113,27 +151,29 @@ impl Bvh {
113151
sp -= 1;
114152
let node_id = stack[sp];
115153
let object_aabb = self.object_aabbs[object1_index];
116-
let Node { aabb, kind, .. } = &self.nodes[usize::try_from(node_id.0).unwrap()];
117-
if object_aabb.intersects(aabb) {
118-
match *kind {
119-
NodeKind::Leaf(object2_index) => {
120-
let object2_index = usize::try_from(object2_index).unwrap();
154+
let node_id = usize::try_from(node_id.0).unwrap();
155+
let aabb = self.node_aabbs[node_id];
156+
if object_aabb.intersects(&aabb) {
157+
let tag = self.node_tags[node_id];
158+
match tag {
159+
NodeTag::Leaf => {
160+
let object2_index = usize::try_from(self.node_leaf_indices[node_id]).unwrap();
121161
if object2_index != object1_index {
122-
let object1_position = self.object_positions[object1_index];
123-
let object2_position = self.object_positions[object2_index];
124-
let object1_radius = self.object_radii[object1_index];
125-
let object2_radius = self.object_radii[object2_index];
162+
let object1_position = positions[object1_index];
163+
let object2_position = positions[object2_index];
164+
let object1_radius = radii[object1_index];
165+
let object2_radius = radii[object2_index];
126166
let distance_squared = (object1_position - object2_position).magnitude_squared();
127167
let collision_distance = object1_radius + object2_radius;
128168
if distance_squared < collision_distance * collision_distance {
129169
candidates.push(NormalizedCollisionPair::new(object1_index, object2_index));
130170
}
131171
}
132172
}
133-
NodeKind::Tree { left, right } => {
134-
stack[sp] = left;
173+
NodeTag::Tree => {
174+
stack[sp] = self.node_tree_left[node_id];
135175
sp += 1;
136-
stack[sp] = right;
176+
stack[sp] = self.node_tree_right[node_id];
137177
sp += 1;
138178
}
139179
}
@@ -180,6 +220,17 @@ impl AABB {
180220
#[derive(Default, Debug, Clone, Copy)]
181221
pub struct NodeId(u32);
182222

223+
impl NodeId {
224+
const INVALID: Self = Self(u32::MAX);
225+
}
226+
227+
#[repr(C)]
228+
#[derive(Clone, Copy)]
229+
pub enum NodeTag {
230+
Leaf,
231+
Tree,
232+
}
233+
183234
#[repr(C)]
184235
#[derive(Clone, Copy)]
185236
pub struct Node {

0 commit comments

Comments
 (0)