Skip to content

Commit 9fbbe75

Browse files
committed
Auto merge of #95602 - scottmcm:faster-array-intoiter-fold, r=the8472
Fix `array::IntoIter::fold` to use the optimized `Range::fold` It was using `Iterator::by_ref` in the implementation, which ended up pessimizing it enough that, for example, it didn't vectorize when we tried it in the <https://rust-lang.zulipchat.com/#narrow/stream/257879-project-portable-simd/topic/Reducing.20sum.20into.20wider.20types> conversation. Demonstration that the codegen test doesn't pass on the current nightly: <https://rust.godbolt.org/z/Taxev5eMn>
2 parents f1f721e + e8fc7ba commit 9fbbe75

File tree

4 files changed

+141
-1
lines changed

4 files changed

+141
-1
lines changed

library/core/src/array/iter.rs

+15-1
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ impl<T, const N: usize> Iterator for IntoIter<T, N> {
266266
Fold: FnMut(Acc, Self::Item) -> Acc,
267267
{
268268
let data = &mut self.data;
269-
self.alive.by_ref().fold(init, |acc, idx| {
269+
iter::ByRefSized(&mut self.alive).fold(init, |acc, idx| {
270270
// SAFETY: idx is obtained by folding over the `alive` range, which implies the
271271
// value is currently considered alive but as the range is being consumed each value
272272
// we read here will only be read once and then considered dead.
@@ -323,6 +323,20 @@ impl<T, const N: usize> DoubleEndedIterator for IntoIter<T, N> {
323323
})
324324
}
325325

326+
#[inline]
327+
fn rfold<Acc, Fold>(mut self, init: Acc, mut rfold: Fold) -> Acc
328+
where
329+
Fold: FnMut(Acc, Self::Item) -> Acc,
330+
{
331+
let data = &mut self.data;
332+
iter::ByRefSized(&mut self.alive).rfold(init, |acc, idx| {
333+
// SAFETY: idx is obtained by folding over the `alive` range, which implies the
334+
// value is currently considered alive but as the range is being consumed each value
335+
// we read here will only be read once and then considered dead.
336+
rfold(acc, unsafe { data.get_unchecked(idx).assume_init_read() })
337+
})
338+
}
339+
326340
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
327341
let len = self.len();
328342

library/core/src/iter/adapters/by_ref_sized.rs

+40
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,35 @@ pub(crate) struct ByRefSized<'a, I>(pub &'a mut I);
99
impl<I: Iterator> Iterator for ByRefSized<'_, I> {
1010
type Item = I::Item;
1111

12+
#[inline]
1213
fn next(&mut self) -> Option<Self::Item> {
1314
self.0.next()
1415
}
1516

17+
#[inline]
1618
fn size_hint(&self) -> (usize, Option<usize>) {
1719
self.0.size_hint()
1820
}
1921

22+
#[inline]
2023
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
2124
self.0.advance_by(n)
2225
}
2326

27+
#[inline]
2428
fn nth(&mut self, n: usize) -> Option<Self::Item> {
2529
self.0.nth(n)
2630
}
2731

32+
#[inline]
2833
fn fold<B, F>(self, init: B, f: F) -> B
2934
where
3035
F: FnMut(B, Self::Item) -> B,
3136
{
3237
self.0.fold(init, f)
3338
}
3439

40+
#[inline]
3541
fn try_fold<B, F, R>(&mut self, init: B, f: F) -> R
3642
where
3743
F: FnMut(B, Self::Item) -> R,
@@ -40,3 +46,37 @@ impl<I: Iterator> Iterator for ByRefSized<'_, I> {
4046
self.0.try_fold(init, f)
4147
}
4248
}
49+
50+
impl<I: DoubleEndedIterator> DoubleEndedIterator for ByRefSized<'_, I> {
51+
#[inline]
52+
fn next_back(&mut self) -> Option<Self::Item> {
53+
self.0.next_back()
54+
}
55+
56+
#[inline]
57+
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
58+
self.0.advance_back_by(n)
59+
}
60+
61+
#[inline]
62+
fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
63+
self.0.nth_back(n)
64+
}
65+
66+
#[inline]
67+
fn rfold<B, F>(self, init: B, f: F) -> B
68+
where
69+
F: FnMut(B, Self::Item) -> B,
70+
{
71+
self.0.rfold(init, f)
72+
}
73+
74+
#[inline]
75+
fn try_rfold<B, F, R>(&mut self, init: B, f: F) -> R
76+
where
77+
F: FnMut(B, Self::Item) -> R,
78+
R: Try<Output = B>,
79+
{
80+
self.0.try_rfold(init, f)
81+
}
82+
}

library/core/tests/array.rs

+32
Original file line numberDiff line numberDiff line change
@@ -668,3 +668,35 @@ fn array_mixed_equality_nans() {
668668
assert!(!(mut3 == array3));
669669
assert!(mut3 != array3);
670670
}
671+
672+
#[test]
673+
fn array_into_iter_fold() {
674+
// Strings to help MIRI catch if we double-free or something
675+
let a = ["Aa".to_string(), "Bb".to_string(), "Cc".to_string()];
676+
let mut s = "s".to_string();
677+
a.into_iter().for_each(|b| s += &b);
678+
assert_eq!(s, "sAaBbCc");
679+
680+
let a = [1, 2, 3, 4, 5, 6];
681+
let mut it = a.into_iter();
682+
it.advance_by(1).unwrap();
683+
it.advance_back_by(2).unwrap();
684+
let s = it.fold(10, |a, b| 10 * a + b);
685+
assert_eq!(s, 10234);
686+
}
687+
688+
#[test]
689+
fn array_into_iter_rfold() {
690+
// Strings to help MIRI catch if we double-free or something
691+
let a = ["Aa".to_string(), "Bb".to_string(), "Cc".to_string()];
692+
let mut s = "s".to_string();
693+
a.into_iter().rev().for_each(|b| s += &b);
694+
assert_eq!(s, "sCcBbAa");
695+
696+
let a = [1, 2, 3, 4, 5, 6];
697+
let mut it = a.into_iter();
698+
it.advance_by(1).unwrap();
699+
it.advance_back_by(2).unwrap();
700+
let s = it.rfold(10, |a, b| 10 * a + b);
701+
assert_eq!(s, 10432);
702+
}

src/test/codegen/simd-wide-sum.rs

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// compile-flags: -C opt-level=3 --edition=2021
2+
// only-x86_64
3+
// ignore-debug: the debug assertions get in the way
4+
5+
#![crate_type = "lib"]
6+
#![feature(portable_simd)]
7+
8+
use std::simd::Simd;
9+
const N: usize = 8;
10+
11+
#[no_mangle]
12+
// CHECK-LABEL: @wider_reduce_simd
13+
pub fn wider_reduce_simd(x: Simd<u8, N>) -> u16 {
14+
// CHECK: zext <8 x i8>
15+
// CHECK-SAME: to <8 x i16>
16+
// CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
17+
let x: Simd<u16, N> = x.cast();
18+
x.reduce_sum()
19+
}
20+
21+
#[no_mangle]
22+
// CHECK-LABEL: @wider_reduce_loop
23+
pub fn wider_reduce_loop(x: Simd<u8, N>) -> u16 {
24+
// CHECK: zext <8 x i8>
25+
// CHECK-SAME: to <8 x i16>
26+
// CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
27+
let mut sum = 0_u16;
28+
for i in 0..N {
29+
sum += u16::from(x[i]);
30+
}
31+
sum
32+
}
33+
34+
#[no_mangle]
35+
// CHECK-LABEL: @wider_reduce_iter
36+
pub fn wider_reduce_iter(x: Simd<u8, N>) -> u16 {
37+
// CHECK: zext <8 x i8>
38+
// CHECK-SAME: to <8 x i16>
39+
// CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
40+
x.as_array().iter().copied().map(u16::from).sum()
41+
}
42+
43+
// This iterator one is the most interesting, as it's the one
44+
// which used to not auto-vectorize due to a suboptimality in the
45+
// `<array::IntoIter as Iterator>::fold` implementation.
46+
47+
#[no_mangle]
48+
// CHECK-LABEL: @wider_reduce_into_iter
49+
pub fn wider_reduce_into_iter(x: Simd<u8, N>) -> u16 {
50+
// CHECK: zext <8 x i8>
51+
// CHECK-SAME: to <8 x i16>
52+
// CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
53+
x.to_array().into_iter().map(u16::from).sum()
54+
}

0 commit comments

Comments
 (0)