Skip to content

Commit a71fce5

Browse files
bors[bot]SkiFire13
andauthored
Merge #518
518: Remove unpredicted branch from kmerge::sift_down r=jswrenn a=SkiFire13 This is pretty much a port from rust-lang/rust#78857 Compared with the previous implementation, this adds more branches for bound checks which aren't present on the stdlib version, however they should be predicted almost always. The speedup should come from the removal of an unpredictable branch from the loop body, in favor of boolean arithmetic. The benchmarks seem to agree: ``` before: test kmerge default ... bench: 6812 ns/iter (+/- 18) test kmerge tenway ... bench: 223673 ns/iter (+/- 769) after: test kmerge default ... bench: 6212 ns/iter (+/- 43) test kmerge tenway ... bench: 190700 ns/iter (+/- 419) ``` Co-authored-by: Giacomo Stevanato <[email protected]>
2 parents 3ced790 + 2cb789c commit a71fce5

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

src/kmerge_impl.rs

+11-7
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,13 @@ fn sift_down<T, S>(heap: &mut [T], index: usize, mut less_than: S)
7474
debug_assert!(index <= heap.len());
7575
let mut pos = index;
7676
let mut child = 2 * pos + 1;
77-
// the `pos` conditional is to avoid a bounds check
78-
while pos < heap.len() && child < heap.len() {
79-
let right = child + 1;
80-
77+
// Require the right child to be present
78+
// This allows to find the index of the smallest child without a branch
79+
// that wouldn't be predicted if present
80+
while child + 1 < heap.len() {
8181
// pick the smaller of the two children
82-
if right < heap.len() && less_than(&heap[right], &heap[child]) {
83-
child = right;
84-
}
82+
// use aritmethic to avoid an unpredictable branch
83+
child += less_than(&heap[child+1], &heap[child]) as usize;
8584

8685
// sift down is done if we are already in order
8786
if !less_than(&heap[child], &heap[pos]) {
@@ -91,6 +90,11 @@ fn sift_down<T, S>(heap: &mut [T], index: usize, mut less_than: S)
9190
pos = child;
9291
child = 2 * pos + 1;
9392
}
93+
// Check if the last (left) child was an only child
94+
// if it is then it has to be compared with the parent
95+
if child + 1 == heap.len() && less_than(&heap[child], &heap[pos]) {
96+
heap.swap(pos, child);
97+
}
9498
}
9599

96100
/// An iterator adaptor that merges an abitrary number of base iterators in ascending order.

0 commit comments

Comments
 (0)